@@ -307,6 +307,14 @@ def _create_transpose_pairs_after_node(self, node):
307
307
self ._g .replace_input (consumer , node .output [0 ], nhwc_node .output [0 ])
308
308
309
309
def _create_transpose_pairs_before_node (self , node ):
310
+ def shape_after_expand (ori_shape ):
311
+ # according to broadcasting rule to expand shape to 4D
312
+ utils .make_sure (ori_shape .count (- 1 ) <= 1 , "shape can contain one -1 at most" )
313
+ ori_rank = len (ori_shape )
314
+ utils .make_sure (ori_rank <= 4 , "ONNX only supports 4D data" )
315
+ new_shape = [1 ]* (4 - ori_rank ) + ori_shape
316
+ return new_shape
317
+
310
318
non_nhwc_trans_inputs = []
311
319
for input_id , n in zip (node .input , node .inputs ):
312
320
if not is_nhwc_transpose (n ):
@@ -316,7 +324,16 @@ def _create_transpose_pairs_before_node(self, node):
316
324
317
325
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nhwc_trans_consumers
318
326
for input_id , n in non_nhwc_trans_inputs :
319
- nchw_node = self ._g .make_node ("Transpose" , [input_id ], attr = {"perm" : [0 , 3 , 1 , 2 ]})
327
+ shape = self ._g .get_shape (n .output [0 ])
328
+ if len (shape ) == 4 :
329
+ nchw_node = self ._g .make_node ("Transpose" , [input_id ], attr = {"perm" : [0 , 3 , 1 , 2 ]})
330
+ else :
331
+ shape_4d = shape_after_expand (shape )
332
+ shape_const = self ._g .make_const (utils .make_name ("reshape_shape" ),
333
+ np_val = np .array (shape_4d , np .int64 )).output [0 ]
334
+ reshape = self ._g .make_node ("Reshape" , [input_id , shape_const ]).output [0 ]
335
+ nchw_node = self ._g .make_node ("Transpose" , [reshape ], attr = {"perm" : [0 , 3 , 1 , 2 ]})
336
+
320
337
nhwc_node = self ._g .make_node ("Transpose" , [nchw_node .output [0 ]], attr = {"perm" : [0 , 2 , 3 , 1 ]})
321
338
self ._g .replace_input (node , input_id , nhwc_node .output [0 ])
322
339
0 commit comments