@@ -793,3 +793,77 @@ class IsNan:
793
793
@classmethod
794
794
def version_9 (cls , ctx , node , ** kwargs ):
795
795
pass
796
+
797
+
798
+ @tf_op ("BatchToSpaceND" , onnx_op = "DepthToSpace" )
799
+ class BatchToSpace :
800
+ @classmethod
801
+ def version_4 (cls , ctx , node , ** kwargs ):
802
+ # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
803
+ # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
804
+ # and we only support 4D here, so the data format is NHWC
805
+ # onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
806
+ # and it only supports NCHW
807
+ # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
808
+ input_tensor = node .inputs [0 ]
809
+ blocksize = node .inputs [1 ].get_tensor_value ()
810
+ crops = node .inputs [2 ].get_tensor_value ()
811
+
812
+ utils .make_sure (len (ctx .get_shape (input_tensor .output [0 ])) == 4 , "only supports 4D for now" )
813
+ utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
814
+ "only support same blocksize at different dims" )
815
+
816
+ ctx .remove_node (node .name )
817
+ # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
818
+ trans1 = ctx .make_node ("Transpose" , input_tensor .output , {"perm" : [3 , 0 , 1 , 2 ]})
819
+ reorganize_node = ctx .make_node (node .type , trans1 .output , attr = {"blocksize" : blocksize [0 ]})
820
+ trans2 = ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]})
821
+
822
+ # implement crop logic, the data format is NHWC
823
+ slice_axis = [1 , 2 ]
824
+ top , bottom = crops [0 ]
825
+ left , right = crops [1 ]
826
+ starts = [top , left ]
827
+ ends = []
828
+ for end in [bottom , right ]:
829
+ if end != 0 :
830
+ ends .append (- end )
831
+ else :
832
+ ends .append (np .iinfo (np .int32 ).max )
833
+
834
+ ctx .make_node ("Slice" , trans2 .output , attr = {"axes" : slice_axis , "ends" : ends , "starts" : starts },
835
+ name = node .name , outputs = node .output )
836
+
837
+
838
+ @tf_op ("SpaceToBatchND" , onnx_op = "SpaceToDepth" )
839
+ class SpaceToBatch :
840
+ @classmethod
841
+ def version_4 (cls , ctx , node , ** kwargs ):
842
+ # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
843
+ # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
844
+ # and we only support 4D here, so the data format is NHWC
845
+ # onnx op "SpaceToDepth" does the same work on input tensor except that it works on "C",
846
+ # and it only supports NCHW
847
+ # T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)
848
+ input_tensor = node .inputs [0 ]
849
+ blocksize = node .inputs [1 ].get_tensor_value ()
850
+ paddings = node .inputs [2 ].get_tensor_value ()
851
+
852
+ utils .make_sure (len (ctx .get_shape (input_tensor .output [0 ])) == 4 , "only supports 4D for now" )
853
+ utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
854
+ "only support same blocksize at different dims" )
855
+
856
+ ctx .remove_node (node .name )
857
+
858
+ # implement pads logic, the data format is NHWC
859
+ top , bottom = paddings [0 ]
860
+ left , right = paddings [1 ]
861
+ pads = [0 , top , left , 0 ,
862
+ 0 , bottom , right , 0 ]
863
+
864
+ pad_op = ctx .make_node ("Pad" , input_tensor .output , attr = {"pads" : pads })
865
+
866
+ # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
867
+ trans1 = ctx .make_node ("Transpose" , pad_op .output , {"perm" : [3 , 0 , 1 , 2 ]})
868
+ reorganize_node = ctx .make_node (node .type , trans1 .output , attr = {"blocksize" : blocksize [0 ]})
869
+ ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]}, name = node .name , outputs = node .output )
0 commit comments