Skip to content

Commit be36457

Browse files
Fix tf1 loop variable parsing for duplicated Enter nodes (#1697)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent f545163 commit be36457

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,19 @@ def __init__(self):
5252
self.tensor_array_inputs = [] # list of type InputTensorArray
5353

5454
def add_variable(self, var):
55-
utils.make_sure(var.enter_name not in self.scan_variables,
56-
"variable %s already exists as scan variable.", var.enter_name)
57-
utils.make_sure(var.enter_name not in self.state_variables,
58-
"variable %s already exists as state variable.", var.enter_name)
55+
key = (var.enter_name, var.merge_name)
56+
utils.make_sure(key not in self.scan_variables,
57+
"variable %r already exists as scan variable.", key)
58+
utils.make_sure(key not in self.state_variables,
59+
"variable %r already exists as state variable.", key)
5960
if var.tensor_array_type == TensorArrayVariableType.READ_LAST:
6061
# If the variable just returns the last value of the constructed tensor array, it doesn't need to be
6162
# a scan output
62-
self.unneeded_scan_variables[var.enter_name] = var
63+
self.unneeded_scan_variables[key] = var
6364
elif var.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
64-
self.scan_variables[var.enter_name] = var
65+
self.scan_variables[key] = var
6566
else:
66-
self.state_variables[var.enter_name] = var
67+
self.state_variables[key] = var
6768

6869
def get_variables(self, checker):
6970
if not checker:
@@ -146,9 +147,10 @@ class LoopVariable(object):
146147
5. the body graph output might go to next iteration as corresponding input
147148
(e.g. switch_true_identity_output.id).
148149
"""
149-
def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
150+
def __init__(self, enter_name, merge_name, enter_input_id, next_iteration_input_id,
150151
switch_true_identity_output_id, exit_output_id, tensor_array_type, ta_index_id, g):
151152
self.enter_name = enter_name
153+
self.merge_name = merge_name
152154
self.enter_input_id = enter_input_id
153155

154156
# the output of iteration body graph for this variable
@@ -330,7 +332,7 @@ def _crop_loop_condition_sub_graph(self, context):
330332
dependent_vars = []
331333
for merge_node in merge_nodes:
332334
enter_node = [n for n in merge_node.inputs if n.type == "Enter"][0]
333-
loop_var = context.loop_properties.all_variables[enter_node.name]
335+
loop_var = context.loop_properties.all_variables[(enter_node.name, merge_node.name)]
334336

335337
# cut off connection between condition graph and Merge node.
336338
# replace condition graph's inputs to be cell graph's outputs, because we want condition graph
@@ -447,7 +449,7 @@ def _get_loop_var_from_switch(self, switch_node):
447449
# update exit output id, treat the gather output as ta's output
448450
exit_output_id = ta_access_node.output[0]
449451

450-
loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
452+
loop_var = LoopVariable(enter_node.name, merge_node.name, target_node_input_id, last_iteration_output_id,
451453
switch_true_identity_output, exit_output_id, ta_type, ta_index_id, self.g)
452454

453455
return loop_var

0 commit comments

Comments
 (0)