```diff
@@ -7,9 +7,8 @@
 from collections import defaultdict
 
 import numpy as np
-
-from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
 import onnx
+from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
 from .. import utils
 from .optimizer_base import GraphOptimizerBase
 
```
```diff
@@ -192,8 +191,8 @@ def _initialize_handlers(self):
 
     def _handle_node_having_branches(self, node):
         # create transpose pairs if some input are not.
-        self._create_transpose_pairs_before_node(node)
-
+        if not self._create_transpose_pairs_before_node(node):
+            return False
         # make sure node's all input transpose all have only 1 consumer node,
         # otherwise, it would impact their other output nodes
         if self._nodes_has_single_consumer_node(node.inputs):
```
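With this change `_handle_node_having_branches` reports failure to its caller instead of asserting via `utils.make_sure`, so the optimizer can skip a node whose inputs cannot be normalized. A minimal standalone sketch of that assert-to-skip pattern (the handler and node representation below are hypothetical, not the optimizer's real API):

```python
# Sketch only, not tf2onnx code: illustrates returning False instead of raising.
def create_transpose_pairs(node):
    # stand-in precondition; the real check inspects input shapes and opset
    return node.get("shape") is not None

def handle_node_having_branches(node):
    if not create_transpose_pairs(node):
        return False  # preconditions not met: leave this node unchanged
    # ... rewrite the node here ...
    return True

print(handle_node_having_branches({"shape": [1, 5, 7, 3]}))  # True
print(handle_node_having_branches({"shape": None}))          # False -> node skipped
```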
```diff
@@ -311,9 +310,10 @@ def _create_transpose_pairs_before_node(self, node):
         def shape_after_expand(ori_shape):
             # according to broadcasting rule to expand shape to 4D while not tile the tensor here
             # still count on the broadcasting op to tile the tensor
-            utils.make_sure(ori_shape.count(-1) <= 1, "shape can contain one -1 at most")
+            if ori_shape.count(-1) >= 2:
+                self.logger.warning("%s shape can contain one -1 at most, otherwise reshape op can't work", node.name)
+                return None
             ori_rank = len(ori_shape)
-            utils.make_sure(ori_rank <= 4, "ONNX only supports 4D data")
             new_shape = [1]*(4 - ori_rank) + ori_shape
             return new_shape
 
```
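For reference, the helper only left-pads a shape with 1s up to rank 4, following the broadcasting rule mentioned in the comment; actual tiling is left to the broadcasting op, and a shape with more than one -1 now yields None instead of an exception. A standalone sketch of the same logic, with the optimizer's logger swapped for a plain print:

```python
# Sketch of shape_after_expand's behaviour (logger replaced by print).
def shape_after_expand(ori_shape):
    if ori_shape.count(-1) >= 2:
        print("shape can contain one -1 at most, otherwise reshape op can't work")
        return None                               # caller gives up on this node
    ori_rank = len(ori_shape)
    return [1] * (4 - ori_rank) + ori_shape       # left-pad with 1s to rank 4

print(shape_after_expand([32, 3]))     # [1, 1, 32, 3]
print(shape_after_expand([-1, 7, 7]))  # [1, -1, 7, 7]
print(shape_after_expand([-1, -1]))    # prints the warning, returns None
```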
```diff
@@ -325,6 +325,14 @@ def shape_after_expand(ori_shape):
                 non_nhwc_trans_inputs.append([input_id, n])
 
         # add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nhwc_trans_consumers
+        shape_unknow = [input_id for input_id, _ in non_nhwc_trans_inputs if self._g.get_shape(input_id) is None]
+        if shape_unknow:
+            if self._g.opset <= 9:
+                msg = "%s 's shape is unknown, ConstantOfShape will be used which exists in version 9 or higher" \
+                      "while graph's opset version is %s" % (shape_unknow, self._g.opset)
+                self.logger.warning(msg)
+                return False
+
         for input_id, n in non_nhwc_trans_inputs:
             shape = self._g.get_shape(input_id)
             # if rank of n is not 4, then we need to insert a reshape op before inserting a transpose
```
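The guard above depends on the graph's default-domain opset, since ConstantOfShape only exists from ONNX opset 9 onward. Outside the optimizer, the same information can be read off a loaded model; a small sketch (the model path is a placeholder):

```python
# Sketch: reading a model's default-domain opset, which is what gates
# the ConstantOfShape-based handling above. "model.onnx" is a placeholder.
import onnx

model = onnx.load("model.onnx")
default_opset = next(
    (imp.version for imp in model.opset_import if imp.domain in ("", "ai.onnx")),
    None,
)
print("default-domain opset:", default_opset)  # e.g. 9, 10, ...
```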
```diff
@@ -343,13 +351,16 @@ def shape_after_expand(ori_shape):
                 input_of_new_trans = input_id
             else:
                 shape_4d = shape_after_expand(shape)
+                if shape_4d is None:
+                    return False
                 const = self._g.make_const(utils.make_name("reshape_shape"), np.array(shape_4d, np.int64)).output[0]
                 reshape = self._g.make_node("Reshape", [input_id, const]).output[0]
                 input_of_new_trans = reshape
 
-            nchw_node = self._g.make_node("Transpose", [input_of_new_trans], attr={"perm": [0, 3, 1, 2]})
-            nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
+            nchw_node = self._g.make_node("Transpose", [input_of_new_trans], attr={"perm": NHWC_TO_NCHW})
+            nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC})
             self._g.replace_input(node, input_id, nhwc_node.output[0])
+        return True
 
     def _add_handler(self, trans, node):
         if node.inputs[1].is_const():
```
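The literal perms are replaced by the named constants now imported from tf2onnx.constants; judging from the removed lines, NHWC_TO_NCHW is [0, 3, 1, 2] and NCHW_TO_NHWC is [0, 2, 3, 1]. A quick numpy check that the inserted Transpose pair is a data-level no-op, which is why it can safely be placed in front of any non-NHWC-transpose input:

```python
import numpy as np

# Values as implied by the replaced literals above.
NHWC_TO_NCHW = [0, 3, 1, 2]
NCHW_TO_NHWC = [0, 2, 3, 1]

x = np.random.rand(1, 5, 7, 3)                # an NHWC tensor
nchw = np.transpose(x, NHWC_TO_NCHW)          # shape (1, 3, 5, 7)
roundtrip = np.transpose(nchw, NCHW_TO_NHWC)  # back to (1, 5, 7, 3)
assert np.array_equal(roundtrip, x)           # the pair cancels out
```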
0 commit comments