Skip to content

Commit 798e991

Browse files
authored
[Graph] Support asynchronous embedding lookup. (#376)
1 parent 8bc8026 commit 798e991

File tree

13 files changed

+539
-9
lines changed

13 files changed

+539
-9
lines changed

tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def sequence_input_layer(
131131
fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns)
132132
sequence_length = _assert_all_equal_and_return(sequence_lengths)
133133

134-
return array_ops.concat(output_tensors, -1), sequence_length
134+
concat_result = array_ops.concat(output_tensors, -1)
135+
ops.add_to_collection(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, concat_result)
136+
return concat_result, sequence_length
135137

136138

137139
def concatenate_context_input(context_input, sequence_input):

tensorflow/contrib/layers/python/layers/embedding_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
186186
final_result.set_shape(
187187
tensor_shape.unknown_shape(
188188
(original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
189+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, final_result)
189190
return final_result
190191

191192
def fused_safe_embedding_lookup_sparse(embedding_weights,

tensorflow/contrib/layers/python/layers/feature_column_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ def _input_from_feature_columns(columns_to_tensors,
164164
'{}, {}'.format(column.name, e, ee))
165165
if cols_to_outs is not None:
166166
cols_to_outs[column] = output_tensors[-1]
167+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, output_tensors[-1])
167168
return array_ops.concat(output_tensors, output_rank - 1)
168169

169-
170170
def input_from_feature_columns(columns_to_tensors,
171171
feature_columns,
172172
weight_collections=None,

tensorflow/core/common_runtime/graph_execution_state.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -763,14 +763,19 @@ Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
763763
if (session_optimizer_options.do_smart_stage() ||
764764
session_optimizer_options.do_smart_stage_gpu()) {
765765
VLOG(2) << "RUN Graph Optimization: SmartStage";
766-
std::string tn;
767-
ReadStringFromEnvVar("TARGET_NODES_NAME", "", &tn);
768-
std::vector<std::string> target_nodes;
769-
for (std::string s : str_util::Split(tn, ';')) {
770-
target_nodes.push_back(s.substr(0, s.find_last_of(':')));
766+
767+
if (session_optimizer_options.do_async_embedding()) {
768+
VLOG(0) << "Async Embedding is enable, disable SmartStage";
769+
} else {
770+
std::string tn;
771+
ReadStringFromEnvVar("TARGET_NODES_NAME", "", &tn);
772+
std::vector<std::string> target_nodes;
773+
for (std::string s : str_util::Split(tn, ';')) {
774+
target_nodes.push_back(s.substr(0, s.find_last_of(':')));
775+
}
776+
SmartStageGraph(&new_graph, target_nodes,
777+
session_optimizer_options.do_smart_stage_gpu());
771778
}
772-
SmartStageGraph(&new_graph, target_nodes,
773-
session_optimizer_options.do_smart_stage_gpu());
774779
}
775780

776781
SaveStatefulNodes(new_graph.get());

tensorflow/core/protobuf/config.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ message OptimizerOptions {
256256
int32 micro_batch_num = 9;
257257
bool do_smart_stage = 10;
258258
bool do_smart_stage_gpu = 11;
259+
bool do_async_embedding = 12;
260+
int32 async_embedding_threads_num = 13;
261+
int32 async_embedding_capacity = 14;
259262
}
260263

261264
message GraphOptions {

tensorflow/python/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4641,6 +4641,8 @@ py_library(
46414641
":util",
46424642
":variable_scope",
46434643
":variables",
4644+
":prefetch",
4645+
":prefetch_runner",
46444646
"//tensorflow/core:protos_all_py",
46454647
"//tensorflow/python/data/ops:dataset_ops",
46464648
"//tensorflow/python/distribute:distribute_coordinator_context",
@@ -6257,6 +6259,20 @@ tf_py_test(
62576259
],
62586260
)
62596261

6262+
tf_py_test(
6263+
name = "async_embedding_stage_test",
6264+
size = "small",
6265+
srcs = ["training/async_embedding_stage_test.py"],
6266+
additional_deps = [
6267+
":training",
6268+
":prefetch",
6269+
":prefetch_runner",
6270+
":variables",
6271+
":math_ops",
6272+
"framework",
6273+
],
6274+
)
6275+
62606276
py_library(
62616277
name = "training_util",
62626278
srcs = ["training/training_util.py"],

tensorflow/python/feature_column/feature_column.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def _get_logits(): # pylint: disable=missing-docstring
218218
scope=variable_scope.get_variable_scope().name)
219219
if cols_to_output_tensors is not None:
220220
cols_to_output_tensors[column] = output_tensor
221+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, output_tensor)
221222
_verify_static_batch_size_equality(output_tensors, ordered_columns)
222223
return array_ops.concat(output_tensors, -1)
223224

tensorflow/python/framework/ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6102,6 +6102,9 @@ class GraphKeys(object):
61026102
EV_INIT_VAR_OPS = "ev_init_var_ops"
61036103
EV_INIT_SLOT_OPS = "ev_init_slot_ops"
61046104

6105+
# Key to collect embedding lookup output result.
6106+
ASYNC_EMBEDDING_OUTPUT_TENSORS = "async_embedding_output_tensors"
6107+
61056108
# Key to indicate various ops.
61066109
INIT_OP = "init_op"
61076110
LOCAL_INIT_OP = "local_init_op"

tensorflow/python/ops/embedding_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def _embedding_lookup_and_transform(params,
339339
if not transform_fn:
340340
# If transform_fn was provided, the clip_by_norm was done above.
341341
ret = _clip(ret, ids, max_norm)
342+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, ret)
342343
return ret
343344

344345

@@ -672,6 +673,7 @@ def embedding_lookup_sparse(params,
672673
else:
673674
assert False, "Unrecognized combiner"
674675

676+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, embeddings)
675677
return embeddings
676678

677679
@tf_export(v1=["nn.adaptive_embedding_lookup_sparse"])
@@ -1341,6 +1343,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
13411343
tensor_shape.unknown_shape(
13421344
(tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
13431345
result.get_shape()[1:]))
1346+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, final_result)
13441347
return final_result
13451348

13461349
def fused_safe_embedding_lookup_sparse(embedding_weights,
@@ -1420,6 +1423,7 @@ def fused_safe_embedding_lookup_sparse(embedding_weights,
14201423
tensor_shape.unknown_shape(
14211424
(tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
14221425
result.get_shape()[1:]))
1426+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, final_result)
14231427
return final_result
14241428

14251429
@tf_export("nn.safe_embedding_lookup_multi_dim")

tensorflow/python/ops/fused_embedding_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def fused_embedding_lookup_sparse(params,
8282
partitioned_values=partitioned_values,
8383
combiner=combiner, max_norm=max_norm, default_id=default_id
8484
)
85+
ops.add_to_collections(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS, emb_vectors)
8586
return emb_vectors
8687

8788

0 commit comments

Comments
 (0)