Skip to content

Commit 97cd275

Browse files
committed
move op defs to tensor.py
1 parent 3e5bf77 commit 97cd275

File tree

2 files changed

+74
-74
lines changed

2 files changed

+74
-74
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -441,80 +441,6 @@ def version_7(cls, ctx, node, **kwargs):
441441
conv_convert_inputs(ctx, node, with_kernel=False)
442442

443443

444-
@tf_op("BatchToSpaceND", onnx_op="DepthToSpace")
445-
class BatchToSpace:
446-
@classmethod
447-
def version_4(cls, ctx, node, **kwargs):
448-
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
449-
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
450-
# and we only support 4D here, so the data format is NHWC
451-
# onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
452-
# and it only supports NCHW
453-
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
454-
input_tensor = node.inputs[0]
455-
blocksize = node.inputs[1].get_tensor_value()
456-
crops = node.inputs[2].get_tensor_value()
457-
458-
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
459-
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
460-
"only support same blocksize at different dims")
461-
462-
ctx.remove_node(node.name)
463-
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
464-
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
465-
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
466-
trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
467-
468-
# implement crop logic, the data format is NHWC
469-
slice_axis = [1, 2]
470-
top, bottom = crops[0]
471-
left, right = crops[1]
472-
starts = [top, left]
473-
ends = []
474-
for end in [bottom, right]:
475-
if end != 0:
476-
ends.append(-end)
477-
else:
478-
ends.append(np.iinfo(np.int32).max)
479-
480-
ctx.make_node("Slice", trans2.output, attr={"axes": slice_axis, "ends": ends, "starts": starts},
481-
name=node.name, outputs=node.output)
482-
483-
484-
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
485-
class SpaceToBatch:
486-
@classmethod
487-
def version_4(cls, ctx, node, **kwargs):
488-
# https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
489-
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
490-
# and we only support 4D here, so the data format is NHWC
491-
# onnx op "SpaceToDepth" does the same work on input tensor except that it works on "C",
492-
# and it only supports NCHW
493-
# T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)
494-
input_tensor = node.inputs[0]
495-
blocksize = node.inputs[1].get_tensor_value()
496-
paddings = node.inputs[2].get_tensor_value()
497-
498-
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
499-
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
500-
"only support same blocksize at different dims")
501-
502-
ctx.remove_node(node.name)
503-
504-
# implement pads logic, the data format is NHWC
505-
top, bottom = paddings[0]
506-
left, right = paddings[1]
507-
pads = [0, top, left, 0,
508-
0, bottom, right, 0]
509-
510-
pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads})
511-
512-
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
513-
trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
514-
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
515-
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output)
516-
517-
518444
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor"])
519445
class ResizeX:
520446
@classmethod

tf2onnx/onnx_opset/tensor.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,77 @@ class IsNan:
793793
@classmethod
794794
def version_9(cls, ctx, node, **kwargs):
795795
pass
796+
797+
798+
@tf_op("BatchToSpaceND", onnx_op="DepthToSpace")
799+
class BatchToSpace:
800+
@classmethod
801+
def version_4(cls, ctx, node, **kwargs):
802+
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
803+
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
804+
# and we only support 4D here, so the data format is NHWC
805+
# onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
806+
# and it only supports NCHW
807+
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
808+
input_tensor = node.inputs[0]
809+
blocksize = node.inputs[1].get_tensor_value()
810+
crops = node.inputs[2].get_tensor_value()
811+
812+
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
813+
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
814+
"only support same blocksize at different dims")
815+
816+
ctx.remove_node(node.name)
817+
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
818+
trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]})
819+
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
820+
trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]})
821+
822+
# implement crop logic, the data format is NHWC
823+
slice_axis = [1, 2]
824+
top, bottom = crops[0]
825+
left, right = crops[1]
826+
starts = [top, left]
827+
ends = []
828+
for end in [bottom, right]:
829+
if end != 0:
830+
ends.append(-end)
831+
else:
832+
ends.append(np.iinfo(np.int32).max)
833+
834+
ctx.make_node("Slice", trans2.output, attr={"axes": slice_axis, "ends": ends, "starts": starts},
835+
name=node.name, outputs=node.output)
836+
837+
838+
@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
839+
class SpaceToBatch:
840+
@classmethod
841+
def version_4(cls, ctx, node, **kwargs):
842+
# https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
843+
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
844+
# and we only support 4D here, so the data format is NHWC
845+
# onnx op "SpaceToDepth" does the same work on input tensor except that it works on "C",
846+
# and it only supports NCHW
847+
# T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)
848+
input_tensor = node.inputs[0]
849+
blocksize = node.inputs[1].get_tensor_value()
850+
paddings = node.inputs[2].get_tensor_value()
851+
852+
utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now")
853+
utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1],
854+
"only support same blocksize at different dims")
855+
856+
ctx.remove_node(node.name)
857+
858+
# implement pads logic, the data format is NHWC
859+
top, bottom = paddings[0]
860+
left, right = paddings[1]
861+
pads = [0, top, left, 0,
862+
0, bottom, right, 0]
863+
864+
pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads})
865+
866+
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
867+
trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]})
868+
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
869+
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output)

0 commit comments

Comments
 (0)