Skip to content

Commit 33430ef

Browse files
authored
[Graph] Fix placement issue of stage node when enabling async embedding. (#452)
1 parent 9f694d8 commit 33430ef

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tensorflow/python/training/async_embedding_stage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _perform_stage(self):
247247

248248
stage_outputs[stage_node.name] = stage_node_outputs
249249
stage_outputs_consumers[stage_node.name] = stage_node_outputs_consumers
250-
with ops.device("/job:worker"):
250+
with ops.colocate_with(list(self._stage_nodes.keys())[0]):
251251
stage_output_result = prefetch.staged(stage_outputs,
252252
num_threads=self._threads_num,
253253
capacity=self._capacity,

0 commit comments

Comments
 (0)