Skip to content

Commit c682ae6

Browse files
Fix bug in while shapes (#1479)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent b002761 commit c682ae6

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,9 +510,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
510510
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
511511
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]
512512

513-
for p, c in zip(loop_node.inputs, func_inputs):
514-
shape = p.output_shapes[0]
515-
g.set_shape(c, shape)
513+
for p, c in zip(loop_node.input, func_inputs):
514+
g.copy_shape(p, c)
516515

517516
for i, node in enumerate(g.inputs):
518517
if node.output[0] not in func_inputs:

0 commit comments

Comments
 (0)