7
7
from collections import defaultdict
8
8
9
9
import numpy as np
10
-
10
+ import onnx
11
11
from tf2onnx .constants import NCHW_TO_NHWC , NHWC_TO_NCHW
12
12
from .. import utils
13
13
from .optimizer_base import GraphOptimizerBase
@@ -191,8 +191,8 @@ def _initialize_handlers(self):
191
191
192
192
def _handle_node_having_branches (self , node ):
193
193
# create transpose pairs if some input are not.
194
- self ._create_transpose_pairs_before_node (node )
195
-
194
+ if not self ._create_transpose_pairs_before_node (node ):
195
+ return False
196
196
# make sure node's all input transpose all have only 1 consumer node,
197
197
# otherwise, it would impact their other output nodes
198
198
if self ._nodes_has_single_consumer_node (node .inputs ):
@@ -307,6 +307,16 @@ def _create_transpose_pairs_after_node(self, node):
307
307
self ._g .replace_input (consumer , node .output [0 ], nhwc_node .output [0 ])
308
308
309
309
def _create_transpose_pairs_before_node (self , node ):
310
+ def shape_after_expand (ori_shape ):
311
+ # according to broadcasting rule to expand shape to 4D while not tile the tensor here
312
+ # still count on the broadcasting op to tile the tensor
313
+ if ori_shape .count (- 1 ) >= 2 :
314
+ self .logger .warning ("%s shape can contain one -1 at most, otherwise reshape op can't work" , node .name )
315
+ return None
316
+ ori_rank = len (ori_shape )
317
+ new_shape = [1 ]* (4 - ori_rank ) + ori_shape
318
+ return new_shape
319
+
310
320
non_nhwc_trans_inputs = []
311
321
for input_id , n in zip (node .input , node .inputs ):
312
322
if not is_nhwc_transpose (n ):
@@ -315,10 +325,42 @@ def _create_transpose_pairs_before_node(self, node):
315
325
non_nhwc_trans_inputs .append ([input_id , n ])
316
326
317
327
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nhwc_trans_consumers
328
+ shape_unknow = [input_id for input_id , _ in non_nhwc_trans_inputs if self ._g .get_shape (input_id ) is None ]
329
+ if shape_unknow :
330
+ if self ._g .opset <= 9 :
331
+ msg = "%s 's shape is unknown, ConstantOfShape will be used which exists in version 9 or higher" \
332
+ "while graph's opset version is %s" % (shape_unknow , self ._g .opset )
333
+ self .logger .warning (msg )
334
+ return False
335
+
318
336
for input_id , n in non_nhwc_trans_inputs :
319
- nchw_node = self ._g .make_node ("Transpose" , [input_id ], attr = {"perm" : [0 , 3 , 1 , 2 ]})
320
- nhwc_node = self ._g .make_node ("Transpose" , [nchw_node .output [0 ]], attr = {"perm" : [0 , 2 , 3 , 1 ]})
337
+ shape = self ._g .get_shape (input_id )
338
+ # if rank of n is not 4, then we need to insert a reshape op before inserting a transpose
339
+ # for example shape of n is [x, y], then output shape of reshape will be [1, 1, x, y]
340
+ if shape is None :
341
+ const_4 = self ._g .make_const (utils .make_name ("const_4" ), np .array ([4 ], np .int64 )).output [0 ]
342
+ tensor_1 = onnx .helper .make_tensor ("value" , onnx .TensorProto .INT64 , [1 ], [1 ])
343
+ shape_node = self ._g .make_node ("Shape" , [input_id ]).output [0 ]
344
+ rank_node = self ._g .make_node ("Shape" , [shape_node ]).output [0 ]
345
+ expand_rank = self ._g .make_node ("Sub" , [const_4 , rank_node ]).output [0 ]
346
+ array_fill_1 = self ._g .make_node ("ConstantOfShape" , [expand_rank ], attr = {"value" : tensor_1 }).output [0 ]
347
+ new_shape = self ._g .make_node ("Concat" , [array_fill_1 , shape_node ], attr = {"axis" : 0 }).output [0 ]
348
+ reshape = self ._g .make_node ("Reshape" , [input_id , new_shape ]).output [0 ]
349
+ input_of_new_trans = reshape
350
+ elif len (shape ) == 4 :
351
+ input_of_new_trans = input_id
352
+ else :
353
+ shape_4d = shape_after_expand (shape )
354
+ if shape_4d is None :
355
+ return False
356
+ const = self ._g .make_const (utils .make_name ("reshape_shape" ), np .array (shape_4d , np .int64 )).output [0 ]
357
+ reshape = self ._g .make_node ("Reshape" , [input_id , const ]).output [0 ]
358
+ input_of_new_trans = reshape
359
+
360
+ nchw_node = self ._g .make_node ("Transpose" , [input_of_new_trans ], attr = {"perm" : NHWC_TO_NCHW })
361
+ nhwc_node = self ._g .make_node ("Transpose" , [nchw_node .output [0 ]], attr = {"perm" : NCHW_TO_NHWC })
321
362
self ._g .replace_input (node , input_id , nhwc_node .output [0 ])
363
+ return True
322
364
323
365
def _add_handler (self , trans , node ):
324
366
if node .inputs [1 ].is_const ():
0 commit comments