Skip to content

Commit be62ec3

Browse files
authored
[Graph] Fix hang bug for async embedding lookup. (#934)
Skip edges to 'SaveV3' Op. Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent 06f81cc commit be62ec3

File tree

2 files changed: 12 additions (+12) and 5 deletions (-5).

tensorflow/python/training/async_embedding_stage.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,13 +49,14 @@ def __init__(self, options, checkpoint_dir = None):
4949
self._checkpoint_dir = checkpoint_dir if checkpoint_dir else ""
5050
self._use_stage_subgraph_thread_pool = options.use_stage_subgraph_thread_pool
5151
self._stage_subgraph_thread_pool_id = options.stage_subgraph_thread_pool_id
52+
self._is_staged = False
5253
self._control_flow_ops = ['Switch', '_SwitchN', 'Merge', '_XlaMerge',
5354
'Enter', 'Exit']
5455
self._variable_ops = ['Variable', 'VariableV2', 'VarHandleOp',
5556
'KvVarHandleOp', 'HashTableV2']
5657
self._variable_is_init_ops = ['IsVariableInitialized',
5758
'VarIsInitializedOp', 'KvVarIsInitializedOp']
58-
self._saver_ops = ['SaveV2']
59+
self._saver_ops = ['SaveV2', 'SaveV3']
5960
self._no_data_input_ops = self._variable_ops + ['Placeholder', 'PlaceholderV2', 'Const']
6061
self._boundary_ops = set()
6162
for tensor in ops.get_collection(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS):
@@ -74,6 +75,10 @@ def __init__(self, options, checkpoint_dir = None):
7475
def stage(self, graph):
7576
""" add async embedding stage node to graph
7677
"""
78+
if self._is_staged:
79+
return
80+
self._is_staged = True
81+
7782
logging.info('async embedding stage begin')
7883
logging.info('async embedding thread num: ' + str(self._threads_num))
7984
logging.info('async embedding capacity: ' + str(self._capacity))

tensorflow/python/training/monitored_session.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -185,6 +185,7 @@ def __init__(self,
185185
self._saver = saver
186186
self._incremental_save_restore = incremental_save_restore
187187
self._incr_saver = None
188+
self._async_embedding_stage = None
188189
self._enable_async_embedding = False
189190
self._async_embedding_checkpoint_dir = None
190191
self._async_embedding_options = None
@@ -247,10 +248,11 @@ def default_ready_for_local_init_op():
247248
self._incr_saver = incr_saver._get_incremental_saver(self._incremental_save_restore, self._saver)
248249

249250
if self._enable_async_embedding:
250-
async_embedding_stage = async_embedding.AsyncEmbeddingStage(
251-
self._async_embedding_options,
252-
self._async_embedding_checkpoint_dir)
253-
async_embedding_stage.stage(ops.get_default_graph())
251+
if self._async_embedding_stage is None:
252+
self._async_embedding_stage = async_embedding.AsyncEmbeddingStage(
253+
self._async_embedding_options,
254+
self._async_embedding_checkpoint_dir)
255+
self._async_embedding_stage.stage(ops.get_default_graph())
254256

255257
ops.get_default_graph().finalize()
256258
logging.info('Graph was finalized.')

0 commit comments

Comments
 (0)