Skip to content

Commit 7538f59

Browse files
authored
Merge pull request #472 from mindest/batch_space
implement onnx ops for SpaceToBatchND and BatchToSpaceND
2 parents 27f5e67 + 96ff4fb commit 7538f59

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2043,5 +2043,30 @@ def test_softsign(self):
20432043
_ = tf.identity(x_, name=_TFOUTPUT)
20442044
self._run_test_case([_OUTPUT], {_INPUT: x_val})
20452045

2046+
def test_batch_to_spacend(self):
    """BatchToSpaceND on a 4D NHWC tensor with a 2x2 block and asymmetric crops."""
    block_shape = [2, 2]
    crops = [[0, 1], [2, 1]]

    x_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)
    x = tf.placeholder(dtype=tf.float32, shape=x_val.shape, name=_TFINPUT)  # NHWC
    _ = tf.batch_to_space_nd(x, block_shape, crops, name=_TFOUTPUT)
    self._run_test_case([_OUTPUT], {_INPUT: x_val})
2054+
2055+
def test_space_to_batchnd(self):
    """SpaceToBatchND on 4D NHWC tensors: one 2x2 block, two padding layouts."""
    block_shape = [2, 2]
    cases = [
        ([[0, 1], [2, 1]], [40, 5, 7, 66]),
        ([[0, 0], [1, 2]], [10, 6, 7, 66]),
    ]
    for i, (pads, shape) in enumerate(cases):
        if i:
            # start the second case from a clean graph, as the original did
            tf.reset_default_graph()
        x_val = np.random.random_sample(shape).astype(np.float32)
        x = tf.placeholder(dtype=tf.float32, shape=x_val.shape, name=_TFINPUT)  # NHWC
        _ = tf.space_to_batch_nd(x, block_shape, pads, name=_TFOUTPUT)
        self._run_test_case([_OUTPUT], {_INPUT: x_val})
2070+
20462071
if __name__ == '__main__':
20472072
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,77 @@ class IsNan:
793793
@classmethod
794794
def version_9(cls, ctx, node, **kwargs):
795795
pass
796+
797+
798+
@tf_op("BatchToSpaceND", onnx_op="DepthToSpace")
class BatchToSpace:
    """Convert TF BatchToSpaceND to ONNX Transpose + DepthToSpace + Transpose + Slice."""

    @classmethod
    def version_4(cls, ctx, node, **kwargs):
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
        # TF's input layout is (batch, spatial_shape, remaining_shape); only the
        # 4D case is handled here, i.e. NHWC.
        # ONNX "DepthToSpace" does the analogous rearrangement but works on "C"
        # and only supports NCHW, so we transpose N into the channel slot, run
        # DepthToSpace there, and transpose back.
        # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
        input_tensor = node.inputs[0]
        blocksize = node.inputs[1].get_tensor_value()
        crops = node.inputs[2].get_tensor_value()

        input_shape = ctx.get_shape(input_tensor.output[0])
        # guard the unknown-shape case: len(None) would raise a bare TypeError
        # instead of the intended make_sure diagnostic
        utils.make_sure(input_shape is not None and len(input_shape) == 4,
                        "only supports 4D for now")
        utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
                        "only support same blocksize at different dims")

        ctx.remove_node(node.name)
        # NHWC -> CNHW, so the onnx op works on "N", same dim tensorflow works on
        trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
        reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
        trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})

        # implement the crop logic via Slice on the spatial dims of the NHWC result
        slice_axis = [1, 2]
        top, bottom = crops[0]
        left, right = crops[1]
        starts = [top, left]
        # a trailing crop of 0 must map to INT32_MAX, not -0: an end of 0 would
        # select an empty slice
        ends = [-end if end != 0 else np.iinfo(np.int32).max
                for end in (bottom, right)]

        ctx.make_node("Slice", trans2.output, attr={"axes": slice_axis, "ends": ends, "starts": starts},
                      name=node.name, outputs=node.output)
836+
837+
838+
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
class SpaceToBatch:
    """Convert TF SpaceToBatchND to ONNX Pad + Transpose + SpaceToDepth + Transpose."""

    @classmethod
    def version_4(cls, ctx, node, **kwargs):
        # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
        # TF's input layout is (batch, spatial_shape, remaining_shape); only the
        # 4D case is handled here, i.e. NHWC.
        # ONNX "SpaceToDepth" does the analogous rearrangement but works on "C"
        # and only supports NCHW, so we transpose N into the channel slot, run
        # SpaceToDepth there, and transpose back.
        # T out = SpaceToBatchND(T input, int32 block_shape, int32 paddings)
        input_tensor = node.inputs[0]
        blocksize = node.inputs[1].get_tensor_value()
        paddings = node.inputs[2].get_tensor_value()

        input_shape = ctx.get_shape(input_tensor.output[0])
        # guard the unknown-shape case: len(None) would raise a bare TypeError
        # instead of the intended make_sure diagnostic
        utils.make_sure(input_shape is not None and len(input_shape) == 4,
                        "only supports 4D for now")
        utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
                        "only support same blocksize at different dims")

        ctx.remove_node(node.name)

        # pad H and W (NHWC layout) before the block rearrangement; ONNX Pad
        # expects [x0_begin, x1_begin, ..., x0_end, x1_end, ...]
        top, bottom = paddings[0]
        left, right = paddings[1]
        pads = [0, top, left, 0,
                0, bottom, right, 0]

        pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads})

        # NHWC -> CNHW, so the onnx op works on "N", same dim tensorflow works on
        trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
        reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
        ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output)

0 commit comments

Comments
 (0)