@@ -52,18 +52,19 @@ def __init__(self):
52
52
self .tensor_array_inputs = [] # list of type InputTensorArray
53
53
54
54
def add_variable (self , var ):
55
- utils .make_sure (var .enter_name not in self .scan_variables ,
56
- "variable %s already exists as scan variable." , var .enter_name )
57
- utils .make_sure (var .enter_name not in self .state_variables ,
58
- "variable %s already exists as state variable." , var .enter_name )
55
+ key = (var .enter_name , var .merge_name )
56
+ utils .make_sure (key not in self .scan_variables ,
57
+ "variable %r already exists as scan variable." , key )
58
+ utils .make_sure (key not in self .state_variables ,
59
+ "variable %r already exists as state variable." , key )
59
60
if var .tensor_array_type == TensorArrayVariableType .READ_LAST :
60
61
# If the variable just returns the last value of the constructed tensor array, it doesn't need to be
61
62
# a scan output
62
- self .unneeded_scan_variables [var . enter_name ] = var
63
+ self .unneeded_scan_variables [key ] = var
63
64
elif var .tensor_array_type == TensorArrayVariableType .GATHER_ALL :
64
- self .scan_variables [var . enter_name ] = var
65
+ self .scan_variables [key ] = var
65
66
else :
66
- self .state_variables [var . enter_name ] = var
67
+ self .state_variables [key ] = var
67
68
68
69
def get_variables (self , checker ):
69
70
if not checker :
@@ -146,9 +147,10 @@ class LoopVariable(object):
146
147
5. the body graph output might go to next iteration as corresponding input
147
148
(e.g. switch_true_identity_output.id).
148
149
"""
149
- def __init__ (self , enter_name , enter_input_id , next_iteration_input_id ,
150
+ def __init__ (self , enter_name , merge_name , enter_input_id , next_iteration_input_id ,
150
151
switch_true_identity_output_id , exit_output_id , tensor_array_type , ta_index_id , g ):
151
152
self .enter_name = enter_name
153
+ self .merge_name = merge_name
152
154
self .enter_input_id = enter_input_id
153
155
154
156
# the output of iteration body graph for this variable
@@ -330,7 +332,7 @@ def _crop_loop_condition_sub_graph(self, context):
330
332
dependent_vars = []
331
333
for merge_node in merge_nodes :
332
334
enter_node = [n for n in merge_node .inputs if n .type == "Enter" ][0 ]
333
- loop_var = context .loop_properties .all_variables [enter_node .name ]
335
+ loop_var = context .loop_properties .all_variables [( enter_node .name , merge_node . name ) ]
334
336
335
337
# cut off connection between condition graph and Merge node.
336
338
# replace condition graph's inputs to be cell graph's outputs, because we want condition graph
@@ -447,7 +449,7 @@ def _get_loop_var_from_switch(self, switch_node):
447
449
# update exit output id, treat the gather output as ta's output
448
450
exit_output_id = ta_access_node .output [0 ]
449
451
450
- loop_var = LoopVariable (enter_node .name , target_node_input_id , last_iteration_output_id ,
452
+ loop_var = LoopVariable (enter_node .name , merge_node . name , target_node_input_id , last_iteration_output_id ,
451
453
switch_true_identity_output , exit_output_id , ta_type , ta_index_id , self .g )
452
454
453
455
return loop_var
0 commit comments