Skip to content

Commit a90ca4d

Browse files
committed
Enhance BatchToSpaceND to support 3D input data
1 parent ffe6792 commit a90ca4d

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

tests/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2359,6 +2359,15 @@ def test_batch_to_spacend(self):
23592359
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
23602360
self._run_test_case([_OUTPUT], {_INPUT: input_val})
23612361

2362+
def test_batch_to_space3d(self):
2363+
block_size = [2, 2]
2364+
crop = [[0, 1], [2, 1]]
2365+
2366+
input_val = np.random.random_sample([40, 3, 100]).astype(np.float32)
2367+
input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHC
2368+
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
2369+
self._run_test_case([_OUTPUT], {_INPUT: input_val})
2370+
23622371
def test_space_to_batchnd(self):
23632372
block_size = [2, 2]
23642373
pad = [[0, 1], [2, 1]]

tf2onnx/onnx_opset/tensor.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,20 +1068,26 @@ class BatchToSpace:
10681068
def version_1(cls, ctx, node, **kwargs):
10691069
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
10701070
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
1071-
# and we only support 4D here, so the data format is NHWC
1071+
# and we only support 3D and 4D inputs here, whose data formats are NHC and NHWC respectively
10721072
# onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
10731073
# and it only supports NCHW
10741074
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
10751075
input_tensor = node.inputs[0]
10761076
blocksize = node.inputs[1].get_tensor_value()
10771077
crops = node.inputs[2].get_tensor_value()
10781078

1079-
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
1079+
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) in (4, 3),
1080+
"only supports 3D and 4D for now")
10801081
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
10811082
"only support same blocksize at different dims")
10821083

10831084
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
1084-
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
1085+
if len(ctx.get_shape(input_tensor.output[0])) == 3:
1086+
# automatically insert an Unsqueeze op to promote the 3D input to 4D
1087+
unsqz1 = ctx.make_node("Unsqueeze", input_tensor.output, {"axes": [3]})
1088+
trans1 = ctx.make_node("Transpose", unsqz1.output, {"perm": [3, 0, 1, 2]})
1089+
else:
1090+
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
10851091
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
10861092
trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
10871093

@@ -1099,11 +1105,19 @@ def version_1(cls, ctx, node, **kwargs):
10991105

11001106
attr = {"axes": slice_axis, "ends": ends, "starts": starts}
11011107
inputs_map = {"data": trans2.output[0], **attr}
1102-
kwargs = {**inputs_map, "outputs": node.output}
11031108
dtypes = [ctx.get_dtype(node.output[0])]
1104-
shapes = [ctx.get_shape(node.output[0])]
1105-
ctx.remove_node(node.name)
1106-
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=shapes)
1109+
shapes = ctx.get_shape(node.output[0])
1110+
1111+
if len(ctx.get_shape(input_tensor.output[0])) == 3:
1112+
# add a Squeeze op to convert the output back to 3D
1113+
kwargs = {**inputs_map}
1114+
ctx.remove_node(node.name)
1115+
slice1 = GraphBuilder(ctx).make_slice(kwargs)
1116+
ctx.make_node("Squeeze", [slice1], {"axes": [3]}, outputs=node.output, name=node.name, dtypes=dtypes)
1117+
else:
1118+
kwargs = {**inputs_map, "outputs": node.output}
1119+
ctx.remove_node(node.name)
1120+
GraphBuilder(ctx).make_slice(kwargs, name=node.name, dtypes=dtypes, shapes=[shapes])
11071121

11081122

11091123
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")

0 commit comments

Comments
 (0)