Skip to content

Commit 330fd73

Browse files
committed
fix ut failure , root cause: loo_rewirter bug
1 parent 14c5735 commit 330fd73

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _get_output_shape_dtype(self, cond_context):
120120
)
121121
if true_dtype != false_dtype:
122122
raise RuntimeError(
123-
"the shape of outputs {} and {} mismatch: {}, {}".format(
123+
"the dtype of outputs {} and {} mismatch: {}, {}".format(
124124
true_output,
125125
false_output,
126126
true_dtype,

tf2onnx/rewriter/loop_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def rewrite(self, context):
102102
return REWRITER_RESULT.FAIL
103103

104104
def _create_loop_node(self, context, loop_props):
105-
# reuse original output connection id (e.g. Exit_XXX), so we don't need set shape.
106105
loop_outputs = []
107106
loop_output_shapes = []
108107
loop_output_dtypes = []
@@ -125,6 +124,7 @@ def _create_loop_node(self, context, loop_props):
125124
loop_node = self.g.make_node("Loop", [trip_cnt.output[0]] + [cond.output[0]] +
126125
loop_props.state_inputs_initial_values, # ONNX Loop support state inputs only
127126
outputs=loop_outputs, op_name_scope="generic_loop",
127+
shapes=loop_output_shapes, dtypes=loop_output_dtypes,
128128
skip_conversion=False)
129129

130130
return loop_node

0 commit comments

Comments
 (0)