@@ -1068,20 +1068,27 @@ class BatchToSpace:
1068
1068
def version_1 (cls , ctx , node , ** kwargs ):
1069
1069
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
1070
1070
# the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape)
1071
- # and we only support 4D here, so the data format is NHWC
1071
+ # and we only support 3D and 4D here, and the data format is NHC and NHWC
1072
1072
# onnx op "DepthToSpace" does the same work on input tensor except that it works on "C",
1073
1073
# and it only supports NCHW
1074
1074
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
1075
1075
input_tensor = node .inputs [0 ]
1076
+ input_shape = ctx .get_shape (input_tensor .output [0 ])
1076
1077
blocksize = node .inputs [1 ].get_tensor_value ()
1077
1078
crops = node .inputs [2 ].get_tensor_value ()
1078
1079
1079
- utils .make_sure (len (ctx .get_shape (input_tensor .output [0 ])) == 4 , "only supports 4D for now" )
1080
+ utils .make_sure (len (input_shape ) in (4 , 3 ),
1081
+ "only supports 3D and 4D for now" )
1080
1082
utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
1081
1083
"only support same blocksize at different dims" )
1082
1084
1083
1085
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
1084
- trans1 = ctx .make_node ("Transpose" , input_tensor .output , {"perm" : [3 , 0 , 1 , 2 ]})
1086
+ if len (input_shape ) == 3 :
1087
+ # insert automatically an Unsqueeze op if the input is 3d
1088
+ unsqz1 = ctx .make_node ("Unsqueeze" , input_tensor .output , {"axes" : [3 ]})
1089
+ trans1 = ctx .make_node ("Transpose" , unsqz1 .output , {"perm" : [3 , 0 , 1 , 2 ]})
1090
+ else :
1091
+ trans1 = ctx .make_node ("Transpose" , input_tensor .output , {"perm" : [3 , 0 , 1 , 2 ]})
1085
1092
reorganize_node = ctx .make_node (node .type , trans1 .output , attr = {"blocksize" : blocksize [0 ]})
1086
1093
trans2 = ctx .make_node ("Transpose" , reorganize_node .output , {"perm" : [1 , 2 , 3 , 0 ]})
1087
1094
@@ -1099,11 +1106,20 @@ def version_1(cls, ctx, node, **kwargs):
1099
1106
1100
1107
attr = {"axes" : slice_axis , "ends" : ends , "starts" : starts }
1101
1108
inputs_map = {"data" : trans2 .output [0 ], ** attr }
1102
- kwargs = {** inputs_map , "outputs" : node .output }
1103
- dtypes = [ctx .get_dtype (node .output [0 ])]
1104
- shapes = [ctx .get_shape (node .output [0 ])]
1105
- ctx .remove_node (node .name )
1106
- GraphBuilder (ctx ).make_slice (kwargs , name = node .name , dtypes = dtypes , shapes = shapes )
1109
+ dtypes = node .output_dtypes
1110
+ shapes = node .output_shapes
1111
+
1112
+ if len (input_shape ) == 3 :
1113
+ # add a squeeze op to convert output into 3d
1114
+ kwargs = {** inputs_map }
1115
+ ctx .remove_node (node .name )
1116
+ slice1 = GraphBuilder (ctx ).make_slice (kwargs )
1117
+ ctx .make_node ("Squeeze" , [slice1 ], {"axes" : [3 ]},
1118
+ outputs = node .output , name = node .name , dtypes = dtypes , shapes = shapes )
1119
+ else :
1120
+ kwargs = {** inputs_map , "outputs" : node .output }
1121
+ ctx .remove_node (node .name )
1122
+ GraphBuilder (ctx ).make_slice (kwargs , name = node .name , dtypes = dtypes , shapes = shapes )
1107
1123
1108
1124
1109
1125
@tf_op ("SpaceToBatchND" , onnx_op = "SpaceToDepth" )
0 commit comments