Skip to content

Commit bcf626c

Browse files
authored
Merge pull request #975 from onnx/gs/lstm-tf-1.15
handle 2 switch output consumers
2 parents 3bf9eae + 6c82816 commit bcf626c

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tf2onnx/rewriter/loop_rewriter_base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,15 @@ def _get_loop_var_from_switch(self, switch_node):
352352
# using grappler there is not necessarily an identity behind switch
353353
switch_true_identity_output = switch_node.output[1]
354354
else:
355-
raise ValueError("switch_true " + switch_node.name + " has unexpected count of consumers:",
356-
[n.name for n in switch_consumers])
355+
# insert identity if there are 2 or more consumers. This can happen on tf-1.15.
356+
switch_true_identity_output = self.g.make_node("Identity", [switch_node.output[1]],
357+
shapes=[switch_node.output_shapes[1]],
358+
dtypes=[switch_node.output_dtypes[1]])
359+
switch_true_identity_output = switch_true_identity_output.output[0]
360+
for n in switch_consumers:
361+
for i, nn in enumerate(n.input):
362+
if nn == switch_node.output[1]:
363+
n.input[i] = switch_true_identity_output
357364

358365
target_node_input_id = None
359366
enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0]

0 commit comments

Comments
 (0)