Skip to content

Commit 0de7d19

Browse files
authored
Merge pull request #451 from zhijxu-MS/push_branch
bug fixes
2 parents 42ade07 + 10e79a9 commit 0de7d19

File tree

5 files changed

+16
-4
lines changed

5 files changed

+16
-4
lines changed

tf2onnx/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,9 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
414414
else:
415415
onnx_tensor = helper.make_tensor(name, utils.map_numpy_to_onnx_dtype(np_val.dtype),
416416
np_val.shape, np_val, raw=False)
417+
dtype = onnx_tensor.data_type
417418
node = self.make_node("Const", [], outputs=[name], name=name, attr={"value": onnx_tensor},
418-
skip_conversion=skip_conversion)
419+
skip_conversion=skip_conversion, dtypes=[dtype])
419420
self.set_shape(name, np_val.shape)
420421
self.set_dtype(name, utils.map_numpy_to_onnx_dtype(np_val.dtype))
421422
return node

tf2onnx/onnx_opset/nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
6666
transpose.set_attr("perm", constants.NHWC_TO_NCHW)
6767
transpose.skip_conversion = True
6868
shape = ctx.get_shape(input_name)
69-
new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
70-
ctx.set_shape(transpose.output[0], new_shape)
69+
if shape is not None:
70+
new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
71+
ctx.set_shape(transpose.output[0], new_shape)
7172
parent.data_format = "NCHW"
7273

7374
# kernel must to be transposed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def post_optimize_action(self):
8181
input_shape = self._g.get_shape(op.input[0])
8282
if not input_shape:
8383
continue
84+
# reshape only supports one dime is -1
85+
if input_shape.count(-1) > 1:
86+
continue
8487

8588
new_shape = []
8689
# when transpose is NHWC_TO_NCHW

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ def _get_output_shape_dtype(self, cond_context):
127127
false_dtype
128128
)
129129
)
130-
output_shapes.append(utils.merge_shapes(true_shape, false_shape))
130+
# in tf, the shape of different branched can be different,
131+
# for example output shape of branch A can be [-1] while branch B can be [1].
132+
# Under this case, we should set output shape to be [-1]
133+
output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
131134
output_dtypes.append(true_dtype)
132135
return output_shapes, output_dtypes
133136

tf2onnx/tfonnx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def rewrite_incomplete_type_support_rs5(g, ops):
473473
def rewrite_incomplete_type_support_rs6(g, ops):
474474
return rewrite_incomplete_type_support(g, ops, [
475475
"Div",
476+
"Greater",
476477
"IsNaN",
478+
"Less",
479+
"Max",
480+
"Min",
477481
"ReduceSum",
478482
"Slice",
479483
"Split",

0 commit comments

Comments
 (0)