Skip to content

Commit d73a80d

Browse files
committed
make some amelioration according to comments
1 parent a90ca4d commit d73a80d

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,16 +1073,17 @@ def version_1(cls, ctx, node, **kwargs):
10731073
# and it only supports NCHW
10741074
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
10751075
input_tensor = node.inputs[0]
1076+
input_shape = ctx.get_shape(input_tensor.output[0])
10761077
blocksize = node.inputs[1].get_tensor_value()
10771078
crops = node.inputs[2].get_tensor_value()
10781079

1079-
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) in (4, 3),
1080+
utils.make_sure(len(input_shape) in (4, 3),
10801081
"only supports 3D and 4D for now")
10811082
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
10821083
"only support same blocksize at different dims")
10831084

10841085
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
1085-
if len(ctx.get_shape(input_tensor.output[0])) == 3:
1086+
if len(input_shape) == 3:
10861087
# insert automatically an Unsqueeze op if the input is 3d
10871088
unsqz1 = ctx.make_node("Unsqueeze", input_tensor.output, {"axes": [3]})
10881089
trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
@@ -1105,19 +1106,20 @@ def version_1(cls, ctx, node, **kwargs):
11051106

11061107
attr = {"axes": slice_axis, "ends": ends, "starts": starts}
11071108
inputs_map = {"data": trans2.output[0], **attr}
1108-
dtypes = [ctx.get_dtype(node.output[0])]
1109-
shapes = ctx.get_shape(node.output[0])
1109+
dtypes = node.output_dtypes
1110+
shapes = node.output_shapes
11101111

1111-
if len(ctx.get_shape(input_tensor.output[0])) == 3:
1112+
if len(input_shape) == 3:
11121113
# add a squeeze op to convert output into 3d
11131114
kwargs = {**inputs_map}
11141115
ctx.remove_node(node.name)
11151116
slice1 = GraphBuilder(ctx).make_slice(kwargs)
1116-
ctx.make_node("Squeeze", [slice1], {"axes": [3]}, outputs=node.output, name=node.name, dtypes=dtypes)
1117+
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
1118+
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
11171119
else:
11181120
kwargs = {**inputs_map, "outputs": node.output}
11191121
ctx.remove_node(node.name)
1120-
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=[shapes])
1122+
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
11211123

11221124

11231125
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")

0 commit comments

Comments
 (0)