Skip to content

Commit 64e4208

Browse files
author
wayuanho
authored
Merge pull request #598 from lei-Qiao/batch_to_space_3d
Enhance BatchToSpaceND to support 3D input data
2 parents fe0614f + d73a80d commit 64e4208

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

tests/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,6 +2366,15 @@ def test_batch_to_spacend(self):
23662366
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
23672367
self._run_test_case([_OUTPUT], {_INPUT: input_val})
23682368

2369+
def test_batch_to_space3d(self):
    """BatchToSpaceND on a 3D (NHC) tensor — covers the rank-3 conversion path."""
    block_shape = [2, 2]
    crops = [[0, 1], [2, 1]]

    # batch of 40 so it divides evenly by block_shape's 2*2 = 4
    input_val = np.random.random_sample([40, 3, 100]).astype(np.float32)
    input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT)  # NHC
    _ = tf.batch_to_space_nd(input_x, block_shape, crops, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: input_val})
23692378
def test_space_to_batchnd(self):
23702379
block_size = [2, 2]
23712380
pad = [[0, 1], [2, 1]]

tf2onnx/onnx_opset/tensor.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,20 +1068,27 @@ class BatchToSpace:
10681068
def version_1(cls, ctx, node, **kwargs):
10691069
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
10701070
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
1071-
# and we only support 4D here, so the data format is NHWC
1071+
# here we only support 3D and 4D input, with data formats NHC and NHWC respectively
10721072
# onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
10731073
# and it only supports NCHW
10741074
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
10751075
input_tensor = node.inputs[0]
1076+
input_shape = ctx.get_shape(input_tensor.output[0])
10761077
blocksize = node.inputs[1].get_tensor_value()
10771078
crops = node.inputs[2].get_tensor_value()
10781079

1079-
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
1080+
utils.make_sure(len(input_shape) in (4, 3),
1081+
"only supports 3D and 4D for now")
10801082
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
10811083
"only support same blocksize at different dims")
10821084

10831085
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
1084-
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
1086+
if len(input_shape) == 3:
1087+
# insert automatically an Unsqueeze op if the input is 3d
1088+
unsqz1 = ctx.make_node("Unsqueeze", input_tensor.output, {"axes": [3]})
1089+
trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
1090+
else:
1091+
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
10851092
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
10861093
trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
10871094

@@ -1099,11 +1106,20 @@ def version_1(cls, ctx, node, **kwargs):
10991106

11001107
attr = {"axes": slice_axis, "ends": ends, "starts": starts}
11011108
inputs_map = {"data": trans2.output[0], **attr}
1102-
kwargs = {**inputs_map, "outputs": node.output}
1103-
dtypes = [ctx.get_dtype(node.output[0])]
1104-
shapes = [ctx.get_shape(node.output[0])]
1105-
ctx.remove_node(node.name)
1106-
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
1109+
dtypes = node.output_dtypes
1110+
shapes = node.output_shapes
1111+
1112+
if len(input_shape) == 3:
1113+
# add a squeeze op to convert output into 3d
1114+
kwargs = {**inputs_map}
1115+
ctx.remove_node(node.name)
1116+
slice1 = GraphBuilder(ctx).make_slice(kwargs)
1117+
ctx.make_node("Squeeze", [slice1], {"axes": [3]},
1118+
outputs=node.output, name=node.name, dtypes=dtypes, shapes=shapes)
1119+
else:
1120+
kwargs = {**inputs_map, "outputs": node.output}
1121+
ctx.remove_node(node.name)
1122+
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
11071123

11081124

11091125
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")

0 commit comments

Comments
 (0)