Skip to content

Commit 0dda3aa

Browse files
authored
[Embedding] Support feature column for MultiHash. (#365)
1 parent 1895f76 commit 0dda3aa

File tree

3 files changed

+328
-132
lines changed

3 files changed

+328
-132
lines changed

docs/Multi-Hash-Variable.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_multihash_variable(name,
4747
4848

4949
**使用示例:**
50+
使用`get_multihash_variable`接口
5051
```python
5152
import tensorflow as tf
5253

@@ -67,7 +68,38 @@ def main(unused_argv):
6768
with tf.Session() as sess:
6869
sess.run([init])
6970
print(sess.run([var, train_op]))
71+
print(sess.run([var, train_op]))
72+
print(sess.run([var, train_op]))
7073

7174
if __name__=="__main__":
7275
tf.app.run()
7376
```
77+
使用`categorical_column_with_multihash`接口
78+
```python
79+
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.feature_column import feature_column_v2 as fc2

# Multi-hash categorical column backed by two tables of sizes dims[0], dims[1].
multihash_column = fc2.categorical_column_with_multihash("col_emb", dims=(2, 2))

# `dimension` is a 2-tuple: one embedding width per table. With the default
# "concat" operation the output width is dimension[0] + dimension[1].
embedding = tf.feature_column.embedding_column(
    categorical_column=multihash_column,
    dimension=(2, 3),
    initializer=tf.ones_initializer(tf.dtypes.float32))

# Sparse integer ids for the feature, keyed by the column's `key`.
features = {
    "col_emb": tf.SparseTensor(
        indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
        values=tf.cast([0, 1, 2, 3], tf.dtypes.int64),
        dense_shape=[4, 4]),
}

emb = tf.feature_column.input_layer(features, [embedding])
doubled = tf.multiply(emb, 2.0, name='multiply')
loss = tf.reduce_sum(doubled, name='reduce_sum')
optimizer = tf.train.AdagradOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars)
init = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init)
  print("init global done")
  # Run three training steps, printing the embedding output each time.
  for _ in range(3):
    print(sess.run([emb, train_op]))
105+
```

tensorflow/python/feature_column/feature_column_v2.py

Lines changed: 192 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,10 @@ def model_fn(features, ...):
929929
ValueError: if `initializer` is specified and is not callable.
930930
RuntimeError: If eager execution is enabled.
931931
"""
932-
if (dimension is None) or (dimension < 1):
932+
if isinstance(categorical_column, MultiHashVariableCategoricalColumn):
933+
if not isinstance(dimension, tuple) or len(dimension) != 2:
934+
raise ValueError('MultiHashVariable dimension error: {}.'.format(dimension))
935+
elif (dimension is None) or (dimension < 1):
933936
raise ValueError('Invalid dimension {}.'.format(dimension))
934937
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
935938
raise ValueError('Must specify both `ckpt_to_load_from` and '
@@ -2054,6 +2057,7 @@ def categorical_column_with_embedding(key,
20542057
):
20552058
return EmbeddingCategoricalColumn(key, dtype, partition_num, ev_option)
20562059

2060+
20572061
@tf_export('feature_column.categorical_column_with_adaptive_embedding')
20582062
def categorical_column_with_adaptive_embedding(key,
20592063
hash_bucket_size,
@@ -2067,6 +2071,42 @@ def categorical_column_with_adaptive_embedding(key,
20672071
partition_num,
20682072
ev_option)
20692073

2074+
2075+
@tf_export('feature_column.categorical_column_with_multihash')
def categorical_column_with_multihash(key,
                                      dims,
                                      complementary_strategy="Q-R",
                                      operation="concat",
                                      dtype=dtypes.int64,
                                      partition_num=None):
  """A `CategoricalColumn` whose integer ids index a multi-hash embedding.

  With the "Q-R" strategy each input id is split into two ids,
  `id // dims[0]` and `id % dims[1]`, which are looked up in two separate
  embedding tables; the two lookups are then combined with `operation`.

  Args:
    key: A unique string identifying the input feature.
    dims: A sequence (tuple or list) describing the sizes of the multi-hash
      tables. If `complementary_strategy` is "Q-R", `len(dims)` must be 2.
    complementary_strategy: Strategy used to split an id. Only "Q-R" is
      supported.
    operation: How the per-table lookups are combined; one of "add", "mul"
      or "concat".
    dtype: The type of the input ids. The column rejects non-integer inputs.
    partition_num: Number of partitions for the embedding variables. When
      `None`, no partitioner is used.

  Returns:
    A `MultiHashVariableCategoricalColumn`.

  Raises:
    ValueError: If `complementary_strategy` or `operation` is not supported,
      or if `dims` does not have exactly 2 entries for the "Q-R" strategy.
  """
  strategy_list = ["Q-R"]
  op_list = ["add", "mul", "concat"]
  # The returned column is a namedtuple that is used as a lookup key (e.g.
  # `inputs.get(self)`), so every field must be hashable. A list for `dims`
  # would make the column unhashable; normalize it to a tuple.
  dims = tuple(dims)
  num_of_partitions = len(dims)
  if complementary_strategy not in strategy_list:
    raise ValueError("The strategy %s is not supported" % complementary_strategy)
  if operation not in op_list:
    raise ValueError("The operation %s is not supported" % operation)
  if complementary_strategy == 'Q-R':
    if num_of_partitions != 2:
      raise ValueError("the num_of_partitions must be 2 when using Q-R strategy.")
  return MultiHashVariableCategoricalColumn(key, dims,
                                            num_of_partitions,
                                            complementary_strategy,
                                            operation,
                                            dtype,
                                            partition_num)
2108+
2109+
20702110
@tf_export('feature_column.categorical_column_with_hash_bucket')
20712111
def categorical_column_with_hash_bucket(key,
20722112
hash_bucket_size,
@@ -4237,6 +4277,18 @@ def variable_shape(self):
42374277
_FEATURE_COLUMN_DEPRECATION)
42384278
def _variable_shape(self):
42394279
return self.variable_shape
4280+
4281+
def _output_shape(self, inputs):
  """Returns the `(batch_size, num_elements)` shape of this column's output.

  For a multi-hash categorical column the width depends on how the two
  lookups are combined: "concat" yields `dimension[0] + dimension[1]`
  elements, any other operation yields `dimension[0]`. All other categorical
  columns defer to the parent implementation.
  """
  column = self.categorical_column
  if not isinstance(column, MultiHashVariableCategoricalColumn):
    return super(EmbeddingColumn, self)._output_shape(inputs)

  batch_size = array_ops.shape(inputs)[0]
  if column.operation == "concat":
    width = self.dimension[0] + self.dimension[1]
  else:
    width = self.dimension[0]
  return (batch_size, width)
42404292

42414293
def create_state(self, state_manager):
42424294
"""Creates the embedding lookup variable."""
@@ -4312,6 +4364,22 @@ def _get_dense_tensor_internal_adaptive_helper(self, sparse_tensors,
43124364
max_norm=self.max_norm,
43134365
adaptive_mask_tensor=self.categorical_column.adaptive_mask_tensor)
43144366

4367+
def _get_dense_tensor_internal_multihash_helper(self, sparse_tensors,
                                                embeddings_q, embeddings_r):
  """Looks up both multi-hash tables and combines the two results.

  Args:
    sparse_tensors: An `IdWeightPair` whose `id_tensor` is the
      `(ids_q, ids_r)` pair produced by the multi-hash categorical column.
    embeddings_q: Embedding weights looked up with `ids_q`.
    embeddings_r: Embedding weights looked up with `ids_r`.

  Returns:
    The combined dense tensor: element-wise sum ("add"), element-wise
    product ("mul"), or concatenation along axis 1 ("concat").

  Raises:
    ValueError: If the column's strategy or operation is not supported.
  """
  column = self.categorical_column
  if column.complementary_strategy != "Q-R":
    raise ValueError(
        "The strategy %s is not supported" % column.complementary_strategy)

  ids_q, ids_r = sparse_tensors.id_tensor

  # The previous one-liner `a, b = None, None if w is None else w` parsed as
  # `(None, (None if w is None else w))`, so weight_q was always None and
  # weight_r received the whole weight tensor. Resolve the weights explicitly.
  weights = sparse_tensors.weight_tensor
  if weights is None:
    weight_q, weight_r = None, None
  elif isinstance(weights, (tuple, list)):
    weight_q, weight_r = weights
  else:
    # NOTE(review): a single weight tensor is applied to both lookups, which
    # assumes it shares indices with ids_q and ids_r -- confirm with callers.
    weight_q = weight_r = weights

  result_q = self._get_dense_tensor_internal_helper(
      CategoricalColumn.IdWeightPair(ids_q, weight_q), embeddings_q)
  result_r = self._get_dense_tensor_internal_helper(
      CategoricalColumn.IdWeightPair(ids_r, weight_r), embeddings_r)

  if column.operation == "add":
    return math_ops.add(result_q, result_r)
  if column.operation == "mul":
    return math_ops.multiply(result_q, result_r)
  if column.operation == "concat":
    return array_ops.concat([result_q, result_r], 1)
  # Fail loudly instead of silently returning None.
  raise ValueError("The operation %s is not supported" % column.operation)
43154383

43164384
def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
43174385
"""Private method that follows the signature of get_dense_tensor."""
@@ -4333,6 +4401,7 @@ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
43334401
weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
43344402
if isinstance(self.categorical_column, AdaptiveEmbeddingCategoricalColumn) \
43354403
or isinstance(self.categorical_column, EmbeddingCategoricalColumn) \
4404+
or isinstance(self.categorical_column, MultiHashVariableCategoricalColumn) \
43364405
or is_sequence_embedding or is_weight_embedding:
43374406
if self.categorical_column.partition_num is None:
43384407
partitioner = None
@@ -4369,6 +4438,26 @@ def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
43694438
)
43704439
return self._get_dense_tensor_internal_helper(sparse_tensors,
43714440
embedding_weights)
4441+
elif isinstance(self.categorical_column, MultiHashVariableCategoricalColumn):
4442+
embedding_weights_q = variable_scope.get_variable(
4443+
name='embedding_weights_q',
4444+
shape=(self.categorical_column.dims[0], self.dimension[0]),
4445+
dtype=dtypes.float32,
4446+
initializer=self.initializer,
4447+
trainable=self.trainable and trainable,
4448+
collections=weight_collections,
4449+
partitioner=partitioner)
4450+
embedding_weights_r = variable_scope.get_variable(
4451+
name='embedding_weights_r',
4452+
shape=(self.categorical_column.dims[1], self.dimension[1]),
4453+
dtype=dtypes.float32,
4454+
initializer=self.initializer,
4455+
trainable=self.trainable and trainable,
4456+
collections=weight_collections,
4457+
partitioner=partitioner)
4458+
return self._get_dense_tensor_internal_multihash_helper(sparse_tensors,
4459+
embedding_weights_q,
4460+
embedding_weights_r)
43724461
else:
43734462
embedding_weights = variable_scope.get_variable(
43744463
name='embedding_weights',
@@ -6230,6 +6319,108 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
62306319
return cls(**kwargs)
62316320

62326321

6322+
class MultiHashVariableCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('MultiHashVariableCategoricalColumn',
                           ('key', 'dims', 'num_of_partitions',
                            'complementary_strategy', 'operation', 'dtype',
                            'partition_num'))):
  """Categorical column that splits each integer id into a pair of ids.

  With the "Q-R" strategy an input id becomes `(id // dims[0], id % dims[1])`.
  The transformed feature is therefore a tuple of two `SparseTensor`s that
  share the input's indices and dense shape.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Splits the sparse input ids according to `complementary_strategy`.

    Args:
      input_tensor: A `SparseTensor` of integer ids.

    Returns:
      A `(sparse_tensor_q, sparse_tensor_r)` tuple holding the quotient and
      remainder ids, with the same indices and dense shape as the input.

    Raises:
      ValueError: If the input is not an integer `SparseTensor`, or the
        strategy is not supported.
    """
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    if not input_tensor.dtype.is_integer:
      raise ValueError('Input type must be a integer.')

    if self.complementary_strategy != "Q-R":
      # Previously this fell through and returned None, which only failed
      # later when the result was unpacked.
      raise ValueError(
          "The strategy %s is not supported" % self.complementary_strategy)

    sparse_id_values = input_tensor.values
    # NOTE(review): ids_q is not clamped here; confirm that callers keep
    # `id // dims[0]` below the row count of the first embedding table.
    ids_q = math_ops.floordiv(sparse_id_values, self.dims[0])
    ids_r = math_ops.floormod(sparse_id_values, self.dims[1])
    sparse_tensor_q = sparse_tensor_lib.SparseTensor(
        input_tensor.indices, ids_q, input_tensor.dense_shape)
    sparse_tensor_r = sparse_tensor_lib.SparseTensor(
        input_tensor.indices, ids_r, input_tensor.dense_shape)
    return (sparse_tensor_q, sparse_tensor_r)

  def transform_feature(self, transformation_cache, state_manager):
    """Splits the ids of the feature into quotient and remainder tensors."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns `dims`, the per-table sizes, rather than a single integer."""
    return self.dims

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    # The id tensor is the (q, r) pair; this column carries no weights.
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def _get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    # Store the dtype by name in the config.
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
6422+
6423+
62336424
class HashOnlyCategoricalColumn(
62346425
CategoricalColumn,
62356426
fc_old._CategoricalColumn, # pylint: disable=protected-access

0 commit comments

Comments
 (0)