Skip to content

Commit fe194b0

Browse files
authored
[Embedding] Fix incorrect frequency in shared-embedding. (#931)
Signed-off-by: lixy9474 <[email protected]>
1 parent 821d5e8 commit fe194b0

File tree

4 files changed

+115
-8
lines changed

4 files changed

+115
-8
lines changed

tensorflow/python/ops/embedding_variable_ops_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2816,5 +2816,79 @@ def testSetInitializedWithRestore(self):
28162816
result = sess.run(var._is_initialized_op)
28172817
self.assertEqual(True, result)
28182818

2819+
def testCountsTensor(self):
  """Checks the frequencies saved for a shared embedding (AdagradDecay).

  Two sparse lookups read the same embedding variable `var_1`, one training
  step is run with `AdagradDecayOptimizer`, and a checkpoint is written.
  The `var_1-freqs` entry of that checkpoint must equal [3, 3, 1, 3, 2],
  which is the combined number of occurrences of ids 0..4 over both lookups
  (id 1 appears twice in `sp1` and once in `sp2`).
  """
  os.environ["TF_RECORD_FREQ"] = "1"
  try:
    checkpoint_directory = self.get_temp_dir()
    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
      var = variable_scope.get_embedding_variable("var_1",
                                                  embedding_dim=3)
      sp1 = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
          values=math_ops.cast([0, 0, 0, 1, 1, 2], dtypes.int64),
          dense_shape=[6, 1])
      sp2 = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
          values=math_ops.cast([3, 3, 3, 4, 4, 1], dtypes.int64),
          dense_shape=[6, 1])
      emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None)
      emb2 = embedding_ops.embedding_lookup_sparse(var, sp2, None)
      emb = emb1 + emb2
      fun = math_ops.multiply(emb, 2.0, name='multiply')
      loss = math_ops.reduce_sum(fun, name='reduce_sum')
      gs = training_util.get_or_create_global_step()
      opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs)
      g_v = opt.compute_gradients(loss)
      train_op = opt.apply_gradients(g_v)
      saver = saver_module.Saver()
      init = variables.global_variables_initializer()
      with self.test_session(graph=g) as sess:
        sess.run([init])
        sess.run(train_op)
        saver.save(sess, ckpt_path)

        # Track whether the frequency entry was seen at all; otherwise a
        # missing "var_1-freqs" entry would let the test pass vacuously.
        freqs_found = False
        for name, _ in checkpoint_utils.list_variables(ckpt_path):
          if name == "var_1-freqs":
            freqs_found = True
            value = checkpoint_utils.load_variable(ckpt_path, name)
            self.assertAllEqual(value, [3, 3, 1, 3, 2])
        self.assertTrue(freqs_found,
                        "var_1-freqs was not found in the checkpoint")
  finally:
    # Always clear the flag so it cannot leak into other tests that run in
    # the same process, even when an assertion above fails.
    os.environ.pop("TF_RECORD_FREQ", None)
2854+
2855+
def testCountsTensorWithGradientDescent(self):
  """Checks the frequencies saved for a shared embedding (GradientDescent).

  Two sparse lookups read the same embedding variable `var_1`, one training
  step is run with `GradientDescentOptimizer`, and a checkpoint is written.
  The `var_1-freqs` entry of that checkpoint must equal [3, 3, 1, 3, 2],
  which is the combined number of occurrences of ids 0..4 over both lookups
  (id 1 appears twice in `sp1` and once in `sp2`).
  """
  os.environ["TF_RECORD_FREQ"] = "1"
  try:
    checkpoint_directory = self.get_temp_dir()
    ckpt_path = os.path.join(checkpoint_directory, "model.ckpt")
    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
      var = variable_scope.get_embedding_variable("var_1",
                                                  embedding_dim=3)
      sp1 = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
          values=math_ops.cast([0, 0, 0, 1, 1, 2], dtypes.int64),
          dense_shape=[6, 1])
      sp2 = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0]],
          values=math_ops.cast([3, 3, 3, 4, 4, 1], dtypes.int64),
          dense_shape=[6, 1])
      emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None)
      emb2 = embedding_ops.embedding_lookup_sparse(var, sp2, None)
      emb = emb1 + emb2
      fun = math_ops.multiply(emb, 2.0, name='multiply')
      loss = math_ops.reduce_sum(fun, name='reduce_sum')
      # The optimizer constructor does not take the global step here, so the
      # result is not bound; the call is kept so the global step is still
      # created at this point in the graph, as before.
      training_util.get_or_create_global_step()
      opt = gradient_descent.GradientDescentOptimizer(0.1)
      g_v = opt.compute_gradients(loss)
      train_op = opt.apply_gradients(g_v)
      saver = saver_module.Saver()
      init = variables.global_variables_initializer()
      with self.test_session(graph=g) as sess:
        sess.run([init])
        sess.run(train_op)
        saver.save(sess, ckpt_path)

        # Track whether the frequency entry was seen at all; otherwise a
        # missing "var_1-freqs" entry would let the test pass vacuously.
        freqs_found = False
        for name, _ in checkpoint_utils.list_variables(ckpt_path):
          if name == "var_1-freqs":
            freqs_found = True
            value = checkpoint_utils.load_variable(ckpt_path, name)
            self.assertAllEqual(value, [3, 3, 1, 3, 2])
        self.assertTrue(freqs_found,
                        "var_1-freqs was not found in the checkpoint")
  finally:
    # Clear the flag even when the test fails; the previous trailing `del`
    # was skipped on any exception and left the variable set.
    os.environ.pop("TF_RECORD_FREQ", None)
2892+
28192893
# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
28202894
googletest.main()

tensorflow/python/ops/kv_variable_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _init_from_args(self,
368368
self._dtype = initial_value.dtype.base_dtype
369369
self._constraint = constraint
370370
self._gather_op = None
371-
self._counts_tensor = None
371+
self._counts_tensor = {}
372372
if self._is_primary:
373373
self._slot_num = 0
374374
else:
@@ -850,7 +850,7 @@ def sparse_read(self, indices, name=None, ev_init_value=None, counts=None):
850850
default_value,
851851
counts, is_inference=True,
852852
name=name)
853-
self._counts_tensor = counts
853+
self._counts_tensor[indices] = counts
854854
else:
855855
value = gen_kv_variable_ops.kv_resource_gather(self._handle,
856856
indices,

tensorflow/python/training/gradient_descent.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,23 @@ def _resource_apply_dense(self, grad, handle):
7171
def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
7272
if isinstance(handle, kv_variable_ops.EmbeddingVariable):
7373
global_step = training_util.get_or_create_global_step()
74-
if handle.need_counts() and handle._counts_tensor is not None:
74+
if handle.need_counts() and len(handle._counts_tensor.keys()) != 0:
75+
if indices.op.type == "ConcatV2":
76+
total_counts = []
77+
for tensor in indices.op.inputs:
78+
if tensor.op.type == "Reshape":
79+
indices_tensor = tensor.op.inputs[0]
80+
total_counts.append(handle._counts_tensor[indices_tensor])
81+
from tensorflow.python.ops import array_ops
82+
counts_tensor = array_ops.concat(total_counts, 0)
83+
elif indices.op.type == "Reshape":
84+
indices_tensor = indices.op.inputs[0]
85+
counts_tensor = handle._counts_tensor[indices_tensor]
7586
return training_ops.kv_resource_sparse_apply_gradient_descent_with_counts(
7687
handle.handle, math_ops.cast(self._learning_rate_tensor,
7788
grad.dtype.base_dtype),
7889
grad, indices, global_step,
79-
handle._counts_tensor, use_locking=self._use_locking)
90+
counts_tensor, use_locking=self._use_locking)
8091
else:
8192
return training_ops.kv_resource_sparse_apply_gradient_descent(
8293
handle.handle, math_ops.cast(self._learning_rate_tensor,

tensorflow/python/training/optimizer.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ def _deduplicate_indexed_slices_with_counts(values, indices):
9393
array_ops.shape(unique_indices)[0])
9494
return (summed_values, unique_indices, indices_counts)
9595

96+
def _deduplicate_indexed_slices_with_counts_reduction(values, indices, counts):
  """Merges duplicate `indices`, summing both `values` and `counts`.

  Args:
    values: Tensor whose slices are summed for every repeated index.
    indices: Tensor of indices, possibly containing duplicates.
    counts: Tensor of counts aligned with `indices`; entries that share an
      index are summed together.

  Returns:
    A tuple `(summed_values, unique_indices, summed_counts)` where
    `unique_indices` is the result of `array_ops.unique(indices)` and the two
    sums are segment sums over the positions reported by that call.
  """
  unique_ids, segment_ids = array_ops.unique(indices)

  def _sum_per_unique_id(data):
    # One output row per unique id; rows of `data` that map to the same id
    # are added together.
    return math_ops.unsorted_segment_sum(
        data, segment_ids, array_ops.shape(unique_ids)[0])

  merged_values = _sum_per_unique_id(values)
  merged_counts = _sum_per_unique_id(counts)
  return (merged_values, unique_ids, merged_counts)
107+
96108
def _var_key(var):
97109
# TODO(ashankar): Consolidate handling for eager and graph
98110
if hasattr(var, "op"):
@@ -1088,14 +1100,24 @@ def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
10881100
"""
10891101
from tensorflow.python.ops import kv_variable_ops
10901102
if isinstance(handle, kv_variable_ops.EmbeddingVariable) and handle.need_counts():
1091-
if handle._counts_tensor is None:
1103+
if len(handle._counts_tensor.keys()) == 0:
10921104
summed_grad, unique_indices, indices_counts = \
10931105
_deduplicate_indexed_slices_with_counts(
10941106
values=grad, indices=indices)
10951107
else:
1096-
summed_grad, unique_indices = _deduplicate_indexed_slices(
1097-
values=grad, indices=indices)
1098-
indices_counts = handle._counts_tensor
1108+
if indices.op.type == "ConcatV2":
1109+
total_counts = []
1110+
for tensor in indices.op.inputs:
1111+
if tensor.op.type == "Reshape":
1112+
indices_tensor = tensor.op.inputs[0]
1113+
total_counts.append(handle._counts_tensor[indices_tensor])
1114+
counts_tensor = array_ops.concat(total_counts, 0)
1115+
elif indices.op.type == "Reshape":
1116+
indices_tensor = indices.op.inputs[0]
1117+
counts_tensor = handle._counts_tensor[indices_tensor]
1118+
summed_grad, unique_indices, indices_counts = \
1119+
_deduplicate_indexed_slices_with_counts_reduction(
1120+
grad, indices, counts_tensor)
10991121
return self._resource_apply_sparse(
11001122
summed_grad, handle, unique_indices, indices_counts)
11011123
else:

0 commit comments

Comments
 (0)