Skip to content

Commit 3e5bf77

Browse files
committed
implement onnx ops for SpaceToBatchND and BatchToSpaceND
1 parent d03e469 commit 3e5bf77

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,5 +2016,30 @@ def test_softsign(self):
20162016
_ = tf.identity(x_, name=_TFOUTPUT)
20172017
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20182018

2019+
def test_batch_to_spacend(self):
    """Check that tf.batch_to_space_nd converts to ONNX and matches TF output."""
    block_shape = [2, 2]
    crops = [[0, 1], [2, 1]]

    x_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)
    # placeholder is NHWC; batch (40) must be divisible by prod(block_shape)
    x = tf.placeholder(dtype=tf.float32, shape=x_val.shape, name=_TFINPUT)
    _ = tf.batch_to_space_nd(x, block_shape, crops, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})
2027+
2028+
def test_space_to_batchnd(self):
    """Check that tf.space_to_batch_nd converts to ONNX and matches TF output."""
    block_shape = [2, 2]

    # case 1: spatial dims need padding to become divisible by the block shape
    paddings = [[0, 1], [2, 1]]
    x_val = np.random.random_sample([40, 5, 7, 66]).astype(np.float32)
    x = tf.placeholder(dtype=tf.float32, shape=x_val.shape, name=_TFINPUT)  # NHWC
    _ = tf.space_to_batch_nd(x, block_shape, paddings, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})

    tf.reset_default_graph()

    # case 2: one spatial dim already divisible, the other padded asymmetrically
    paddings = [[0, 0], [1, 2]]
    x_val = np.random.random_sample([10, 6, 7, 66]).astype(np.float32)
    x = tf.placeholder(dtype=tf.float32, shape=x_val.shape, name=_TFINPUT)  # NHWC
    _ = tf.space_to_batch_nd(x, block_shape, paddings, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})
2043+
20192044
if __name__ == '__main__':
20202045
unittest_main()

tf2onnx/onnx_opset/nn.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,80 @@ def version_7(cls, ctx, node, **kwargs):
441441
conv_convert_inputs(ctx, node, with_kernel=False)
442442

443443

444+
@tf_op("BatchToSpaceND", onnx_op="DepthToSpace")
class BatchToSpace:
    """Maps TF BatchToSpaceND onto ONNX DepthToSpace plus transposes and a Slice."""
    @classmethod
    def version_4(cls, ctx, node, **kwargs):
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
        # the above link says the data format of input tensor should be
        # (batch, spatial_shape, remaining_shape), and we only support 4D here,
        # so the data format is NHWC.
        # onnx op "DepthToSpace" does the same work on input tensor except that
        # it works on "C", and it only supports NCHW.
        # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
        input_tensor = node.inputs[0]
        # block_shape and crops must be graph constants for this conversion
        blocksize = node.inputs[1].get_tensor_value()
        crops = node.inputs[2].get_tensor_value()

        utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
        # DepthToSpace has a single scalar blocksize attribute, so the two
        # spatial block dims must be equal
        utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
                        "only support same blocksize at different dims")

        # remove the TF node first so its name/outputs can be reused by the
        # final ONNX node below
        ctx.remove_node(node.name)
        # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
        trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
        # node.type is already the mapped onnx_op ("DepthToSpace")
        reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
        # CNHW back to NHWC
        trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})

        # implement crop logic, the data format is NHWC
        slice_axis = [1, 2]
        top, bottom = crops[0]
        left, right = crops[1]
        starts = [top, left]
        ends = []
        for end in [bottom, right]:
            if end != 0:
                # negative end crops from the tail of the axis
                ends.append(-end)
            else:
                # no crop: int32 max acts as "slice to the end" sentinel
                ends.append(np.iinfo(np.int32).max)

        # final node takes over the original node's name and output edges,
        # so downstream consumers stay connected
        ctx.make_node("Slice", trans2.output, attr={"axes": slice_axis, "ends": ends, "starts": starts},
                      name=node.name, outputs=node.output)
482+
483+
484+
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
class SpaceToBatch:
    """Maps TF SpaceToBatchND onto ONNX Pad + SpaceToDepth plus transposes."""
    @classmethod
    def version_4(cls, ctx, node, **kwargs):
        # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
        # the above link says the data format of input tensor should be
        # (batch, spatial_shape, remaining_shape), and we only support 4D here,
        # so the data format is NHWC.
        # onnx op "SpaceToDepth" does the same work on input tensor except that
        # it works on "C", and it only supports NCHW.
        # T out = SpaceToBatchND(T input, int32 block_shape, int32 paddings)
        input_tensor = node.inputs[0]
        # block_shape and paddings must be graph constants for this conversion
        blocksize = node.inputs[1].get_tensor_value()
        paddings = node.inputs[2].get_tensor_value()

        utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
        # SpaceToDepth has a single scalar blocksize attribute, so the two
        # spatial block dims must be equal
        utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
                        "only support same blocksize at different dims")

        # remove the TF node first so its name/outputs can be reused by the
        # final ONNX node below
        ctx.remove_node(node.name)

        # implement pads logic, the data format is NHWC
        # ONNX Pad "pads" layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
        top, bottom = paddings[0]
        left, right = paddings[1]
        pads = [0, top, left, 0,
                0, bottom, right, 0]

        pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads})

        # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
        trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
        # node.type is already the mapped onnx_op ("SpaceToDepth")
        reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
        # CNHW back to NHWC; final node takes over the original node's
        # name and output edges, so downstream consumers stay connected
        ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output)
516+
517+
444518
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor"])
445519
class ResizeX:
446520
@classmethod

0 commit comments

Comments
 (0)