@@ -339,6 +339,7 @@ def _embedding_lookup_and_transform(params,
339339 if not transform_fn :
340340 # If transform_fn was provided, the clip_by_norm was done above.
341341 ret = _clip (ret , ids , max_norm )
342+ ops .add_to_collections (ops .GraphKeys .ASYNC_EMBEDDING_OUTPUT_TENSORS , ret )
342343 return ret
343344
344345
@@ -672,6 +673,7 @@ def embedding_lookup_sparse(params,
672673 else :
673674 assert False , "Unrecognized combiner"
674675
676+ ops .add_to_collections (ops .GraphKeys .ASYNC_EMBEDDING_OUTPUT_TENSORS , embeddings )
675677 return embeddings
676678
677679@tf_export (v1 = ["nn.adaptive_embedding_lookup_sparse" ])
@@ -1341,6 +1343,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
13411343 tensor_shape .unknown_shape (
13421344 (tensor_shape .Dimension (original_rank_dim ) - 1 ).value ).concatenate (
13431345 result .get_shape ()[1 :]))
1346+ ops .add_to_collections (ops .GraphKeys .ASYNC_EMBEDDING_OUTPUT_TENSORS , final_result )
13441347 return final_result
13451348
13461349def fused_safe_embedding_lookup_sparse (embedding_weights ,
@@ -1420,6 +1423,7 @@ def fused_safe_embedding_lookup_sparse(embedding_weights,
14201423 tensor_shape .unknown_shape (
14211424 (tensor_shape .Dimension (original_rank_dim ) - 1 ).value ).concatenate (
14221425 result .get_shape ()[1 :]))
1426+ ops .add_to_collections (ops .GraphKeys .ASYNC_EMBEDDING_OUTPUT_TENSORS , final_result )
14231427 return final_result
14241428
14251429@tf_export ("nn.safe_embedding_lookup_multi_dim" )
0 commit comments