Skip to content

Commit 10e79a9

Browse files
committed
refactor code
1 parent f05e36d commit 10e79a9

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ def _get_output_shape_dtype(self, cond_context):
127127
false_dtype
128128
)
129129
)
130+
# in tf, the shape of different branched can be different,
131+
# for example output shape of branch A can be [-1] while branch B can be [1].
132+
# Under this case, we should set output shape to be [-1]
130133
output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
131134
output_dtypes.append(true_dtype)
132135
return output_shapes, output_dtypes

tf2onnx/tfonnx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,8 @@ def rewrite_incomplete_type_support_rs6(g, ops):
475475
"Div",
476476
"Greater",
477477
"IsNaN",
478+
"Less",
479+
"Max",
478480
"Min",
479481
"ReduceSum",
480482
"Slice",

0 commit comments

Comments
 (0)