@@ -1073,16 +1073,17 @@ def version_1(cls, ctx, node, **kwargs):
1073
1073
# and it only supports NCHW
1074
1074
# T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)
1075
1075
input_tensor = node .inputs [0 ]
1076
+ input_shape = ctx .get_shape (input_tensor .output [0 ])
1076
1077
blocksize = node .inputs [1 ].get_tensor_value ()
1077
1078
crops = node .inputs [2 ].get_tensor_value ()
1078
1079
1079
- utils .make_sure (len (ctx . get_shape ( input_tensor . output [ 0 ]) ) in (4 , 3 ),
1080
+ utils .make_sure (len (input_shape ) in (4 , 3 ),
1080
1081
"only supports 3D and 4D for now" )
1081
1082
utils .make_sure (len (blocksize ) == 2 and blocksize [0 ] == blocksize [1 ],
1082
1083
"only support same blocksize at different dims" )
1083
1084
1084
1085
# NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow
1085
- if len (ctx . get_shape ( input_tensor . output [ 0 ]) ) == 3 :
1086
+ if len (input_shape ) == 3 :
1086
1087
# insert automatically an Unsqueeze op if the input is 3d
1087
1088
unsqz1 = ctx .make_node ("Unsqueeze" , input_tensor .output , {"axes" : [3 ]})
1088
1089
trans1 = ctx .make_node ("Transpose" , unsqz1 .output , {"perm" : [3 , 0 , 1 , 2 ]})
@@ -1105,19 +1106,20 @@ def version_1(cls, ctx, node, **kwargs):
1105
1106
1106
1107
attr = {"axes" : slice_axis , "ends" : ends , "starts" : starts }
1107
1108
inputs_map = {"data" : trans2 .output [0 ], ** attr }
1108
- dtypes = [ ctx . get_dtype ( node .output [ 0 ])]
1109
- shapes = ctx . get_shape ( node .output [ 0 ])
1109
+ dtypes = node .output_dtypes
1110
+ shapes = node .output_shapes
1110
1111
1111
- if len (ctx . get_shape ( input_tensor . output [ 0 ]) ) == 3 :
1112
+ if len (input_shape ) == 3 :
1112
1113
# add a squeeze op to convert output into 3d
1113
1114
kwargs = {** inputs_map }
1114
1115
ctx .remove_node (node .name )
1115
1116
slice1 = GraphBuilder (ctx ).make_slice (kwargs )
1116
- ctx .make_node ("Squeeze" , [slice1 ], {"axes" : [3 ]}, outputs = node .output , name = node .name , dtypes = dtypes )
1117
+ ctx .make_node ("Squeeze" , [slice1 ], {"axes" : [3 ]},
1118
+ outputs = node .output , name = node .name , dtypes = dtypes , shapes = shapes )
1117
1119
else :
1118
1120
kwargs = {** inputs_map , "outputs" : node .output }
1119
1121
ctx .remove_node (node .name )
1120
- GraphBuilder (ctx ).make_slice (kwargs , name = node .name , dtypes = dtypes , shapes = [ shapes ] )
1122
+ GraphBuilder (ctx ).make_slice (kwargs , name = node .name , dtypes = dtypes , shapes = shapes )
1121
1123
1122
1124
1123
1125
@tf_op ("SpaceToBatchND" , onnx_op = "SpaceToDepth" )
0 commit comments