@@ -929,7 +929,10 @@ def model_fn(features, ...):
929929 ValueError: if `initializer` is specified and is not callable.
930930 RuntimeError: If eager execution is enabled.
931931 """
932- if (dimension is None ) or (dimension < 1 ):
932+ if isinstance (categorical_column , MultiHashVariableCategoricalColumn ):
933+ if not isinstance (dimension , tuple ) or len (dimension ) != 2 :
934+ raise ValueError ('MultiHashVariable dimension error: {}.' .format (dimension ))
935+ elif (dimension is None ) or (dimension < 1 ):
933936 raise ValueError ('Invalid dimension {}.' .format (dimension ))
934937 if (ckpt_to_load_from is None ) != (tensor_name_in_ckpt is None ):
935938 raise ValueError ('Must specify both `ckpt_to_load_from` and '
@@ -2054,6 +2057,7 @@ def categorical_column_with_embedding(key,
20542057 ):
20552058 return EmbeddingCategoricalColumn (key , dtype , partition_num , ev_option )
20562059
2060+
20572061@tf_export ('feature_column.categorical_column_with_adaptive_embedding' )
20582062def categorical_column_with_adaptive_embedding (key ,
20592063 hash_bucket_size ,
@@ -2067,6 +2071,42 @@ def categorical_column_with_adaptive_embedding(key,
20672071 partition_num ,
20682072 ev_option )
20692073
2074+
@tf_export('feature_column.categorical_column_with_multihash')
def categorical_column_with_multihash(key,
                                      dims,
                                      complementary_strategy="Q-R",
                                      operation="concat",
                                      dtype=dtypes.int64,
                                      partition_num=None):
  """A `CategoricalColumn` whose ids are looked up in a multi-hash table.

  Args:
    key: A unique string identifying the input feature.
    dims: A sequence which describes the shape of the multi-hash table.
      If complementary_strategy is "Q-R", `len(dims)` must be 2.
    complementary_strategy: How an id is split across the tables. Only
      "Q-R" is currently supported.
    operation: How the per-table results are combined; one of "add",
      "mul" or "concat".
    dtype: Stored on the returned column. Defaults to `dtypes.int64`.
    partition_num: Stored on the returned column. Defaults to None.

  Returns:
    A `MultiHashVariableCategoricalColumn`.

  Raises:
    ValueError: if `complementary_strategy` or `operation` is not
      supported, or if `len(dims)` is not 2 with the "Q-R" strategy.
  """
  strategy_list = ["Q-R"]
  op_list = ["add", "mul", "concat"]
  # The returned column is a namedtuple; a list-valued field would make it
  # unhashable. Columns are passed as lookup keys to the transformation
  # cache (presumably hashed there), so normalize `dims` to a tuple.
  dims = tuple(dims)
  num_of_partitions = len(dims)
  if complementary_strategy not in strategy_list:
    raise ValueError("The strategy %s is not supported" % complementary_strategy)
  if operation not in op_list:
    raise ValueError("The operation %s is not supported" % operation)
  if complementary_strategy == 'Q-R':
    if num_of_partitions != 2:
      raise ValueError("the num_of_partitions must be 2 when using Q-R strategy.")
  return MultiHashVariableCategoricalColumn(key, dims,
                                            num_of_partitions,
                                            complementary_strategy,
                                            operation,
                                            dtype,
                                            partition_num)
2108+
2109+
20702110@tf_export ('feature_column.categorical_column_with_hash_bucket' )
20712111def categorical_column_with_hash_bucket (key ,
20722112 hash_bucket_size ,
@@ -4237,6 +4277,18 @@ def variable_shape(self):
42374277 _FEATURE_COLUMN_DEPRECATION )
42384278 def _variable_shape (self ):
42394279 return self .variable_shape
4280+
def _output_shape(self, inputs):
  """Returns the output shape of this column as a tuple.

  For a multi-hash categorical column the result is
  `(batch_size, num_elements)`, where `batch_size` is the first entry of
  the dynamic shape of `inputs` and `num_elements` is
  `dimension[0] + dimension[1]` for the "concat" operation and
  `dimension[0]` otherwise. Any other categorical column defers to the
  parent class implementation.
  """
  column = self.categorical_column
  # Guard clause: ordinary columns use the inherited shape logic.
  if not isinstance(column, MultiHashVariableCategoricalColumn):
    return super(EmbeddingColumn, self)._output_shape(inputs)

  rows = array_ops.shape(inputs)[0]
  width = self.dimension[0]
  if column.operation == "concat":
    # Concatenation places the two lookup results side by side.
    width = width + self.dimension[1]
  return (rows, width)
42404292
42414293 def create_state (self , state_manager ):
42424294 """Creates the embedding lookup variable."""
@@ -4312,6 +4364,22 @@ def _get_dense_tensor_internal_adaptive_helper(self, sparse_tensors,
43124364 max_norm = self .max_norm ,
43134365 adaptive_mask_tensor = self .categorical_column .adaptive_mask_tensor )
43144366
def _get_dense_tensor_internal_multihash_helper(self, sparse_tensors,
                                                embeddings_q, embeddings_r):
  """Looks up both multi-hash tables and combines the two results.

  Args:
    sparse_tensors: An id/weight pair whose `id_tensor` unpacks into the
      two id tensors `(ids_q, ids_r)`.
    embeddings_q: Embedding weights used for `ids_q`.
    embeddings_r: Embedding weights used for `ids_r`.

  Returns:
    The two lookup results combined with `math_ops.add` ("add"),
    `math_ops.multiply` ("mul") or `array_ops.concat` along axis 1
    ("concat"), according to the categorical column's `operation`.

  Raises:
    ValueError: if the column's `complementary_strategy` or `operation`
      is not supported (previously this fell through and returned None).
  """
  column = self.categorical_column
  if column.complementary_strategy != "Q-R":
    raise ValueError(
        "The strategy %s is not supported" % column.complementary_strategy)

  ids_q, ids_r = sparse_tensors.id_tensor
  # The original one-liner `a, b = None, None if cond else w` parsed as
  # `(None, (None if cond else w))`, so weight_q was always None and
  # weight_r received the whole weight tensor. Branch explicitly instead.
  if sparse_tensors.weight_tensor is None:
    weight_q, weight_r = None, None
  else:
    # NOTE(review): assumes weight_tensor is a (weight_q, weight_r) pair
    # mirroring id_tensor -- TODO confirm against callers.
    weight_q, weight_r = sparse_tensors.weight_tensor

  result_q = self._get_dense_tensor_internal_helper(
      CategoricalColumn.IdWeightPair(ids_q, weight_q), embeddings_q)
  result_r = self._get_dense_tensor_internal_helper(
      CategoricalColumn.IdWeightPair(ids_r, weight_r), embeddings_r)

  if column.operation == "add":
    return math_ops.add(result_q, result_r)
  if column.operation == "mul":
    return math_ops.multiply(result_q, result_r)
  if column.operation == "concat":
    return array_ops.concat([result_q, result_r], 1)
  raise ValueError("The operation %s is not supported" % column.operation)
43154383
43164384 def _get_dense_tensor_internal (self , sparse_tensors , state_manager ):
43174385 """Private method that follows the signature of get_dense_tensor."""
@@ -4333,6 +4401,7 @@ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
43334401 weight_collections .append (ops .GraphKeys .GLOBAL_VARIABLES )
43344402 if isinstance (self .categorical_column , AdaptiveEmbeddingCategoricalColumn ) \
43354403 or isinstance (self .categorical_column , EmbeddingCategoricalColumn ) \
4404+ or isinstance (self .categorical_column , MultiHashVariableCategoricalColumn ) \
43364405 or is_sequence_embedding or is_weight_embedding :
43374406 if self .categorical_column .partition_num is None :
43384407 partitioner = None
@@ -4369,6 +4438,26 @@ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
43694438 )
43704439 return self ._get_dense_tensor_internal_helper (sparse_tensors ,
43714440 embedding_weights )
4441+ elif isinstance (self .categorical_column , MultiHashVariableCategoricalColumn ):
4442+ embedding_weights_q = variable_scope .get_variable (
4443+ name = 'embedding_weights_q' ,
4444+ shape = (self .categorical_column .dims [0 ], self .dimension [0 ]),
4445+ dtype = dtypes .float32 ,
4446+ initializer = self .initializer ,
4447+ trainable = self .trainable and trainable ,
4448+ collections = weight_collections ,
4449+ partitioner = partitioner )
4450+ embedding_weights_r = variable_scope .get_variable (
4451+ name = 'embedding_weights_r' ,
4452+ shape = (self .categorical_column .dims [1 ], self .dimension [1 ]),
4453+ dtype = dtypes .float32 ,
4454+ initializer = self .initializer ,
4455+ trainable = self .trainable and trainable ,
4456+ collections = weight_collections ,
4457+ partitioner = partitioner )
4458+ return self ._get_dense_tensor_internal_multihash_helper (sparse_tensors ,
4459+ embedding_weights_q ,
4460+ embedding_weights_r )
43724461 else :
43734462 embedding_weights = variable_scope .get_variable (
43744463 name = 'embedding_weights' ,
@@ -6230,6 +6319,108 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
62306319 return cls (** kwargs )
62316320
62326321
class MultiHashVariableCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('MultiHashVariableCategoricalColumn',
                           ('key', 'dims', 'num_of_partitions',
                            'complementary_strategy', 'operation', 'dtype',
                            'partition_num'))):
  """Categorical column that splits each integer id into a pair of ids.

  With the "Q-R" strategy every input id is turned into a quotient id
  (`id // dims[0]`) and a remainder id (`id % dims[1]`); the transformed
  feature is the pair of sparse tensors holding those ids.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Splits the ids of `input_tensor` into quotient and remainder ids.

    Args:
      input_tensor: A `SparseTensor` with an integer dtype.

    Returns:
      A tuple `(sparse_tensor_q, sparse_tensor_r)`; both reuse the indices
      and dense shape of `input_tensor`.

    Raises:
      ValueError: if `input_tensor` is not a `SparseTensor`, is not of an
        integer dtype, or the column's strategy is not supported.
    """
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    if not input_tensor.dtype.is_integer:
      raise ValueError('Input type must be a integer.')

    if self.complementary_strategy != "Q-R":
      # Previously this fell through and returned None.
      raise ValueError(
          "The strategy %s is not supported" % self.complementary_strategy)

    sparse_id_values = input_tensor.values
    # NOTE(review): the quotient divides by dims[0] while the remainder uses
    # dims[1]; the quotient only stays below dims[0] when ids are smaller
    # than dims[0] ** 2 -- confirm this is the intended id range.
    ids_q = math_ops.floordiv(sparse_id_values, self.dims[0])
    ids_r = math_ops.floormod(sparse_id_values, self.dims[1])
    sparse_tensor_q = sparse_tensor_lib.SparseTensor(
        input_tensor.indices, ids_q, input_tensor.dense_shape)
    sparse_tensor_r = sparse_tensor_lib.SparseTensor(
        input_tensor.indices, ids_r, input_tensor.dense_shape)
    return (sparse_tensor_q, sparse_tensor_r)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns the quotient/remainder sparse tensors for this feature."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns `dims`, i.e. the bucket count of every hash table."""
    return self.dims

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    # This column never supplies weights, so the weight tensor is None.
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    # Store the dtype by name rather than as a DType object.
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
6422+
6423+
62336424class HashOnlyCategoricalColumn (
62346425 CategoricalColumn ,
62356426 fc_old ._CategoricalColumn , # pylint: disable=protected-access
0 commit comments