Skip to content

Commit e1f837a

Browse files
committed
refactor
1 parent d2930f8 commit e1f837a

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from collections import defaultdict
88

99
import numpy as np
10-
11-
from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
1210
import onnx
11+
from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
1312
from .. import utils
1413
from .optimizer_base import GraphOptimizerBase
1514

@@ -192,8 +191,8 @@ def _initialize_handlers(self):
192191

193192
def _handle_node_having_branches(self, node):
194193
# create transpose pairs if some input are not.
195-
self._create_transpose_pairs_before_node(node)
196-
194+
if not self._create_transpose_pairs_before_node(node):
195+
return False
197196
# make sure node's all input transpose all have only 1 consumer node,
198197
# otherwise, it would impact their other output nodes
199198
if self._nodes_has_single_consumer_node(node.inputs):
@@ -311,9 +310,10 @@ def _create_transpose_pairs_before_node(self, node):
311310
def shape_after_expand(ori_shape):
312311
# according to broadcasting rule to expand shape to 4D while not tile the tensor here
313312
# still count on the broadcasting op to tile the tensor
314-
utils.make_sure(ori_shape.count(-1) <= 1, "shape can contain one -1 at most")
313+
if ori_shape.count(-1) >= 2:
314+
self.logger.warning("%s shape can contain one -1 at most, otherwise reshape op can't work", node.name)
315+
return None
315316
ori_rank = len(ori_shape)
316-
utils.make_sure(ori_rank <= 4, "ONNX only supports 4D data")
317317
new_shape = [1]*(4-ori_rank) + ori_shape
318318
return new_shape
319319

@@ -325,6 +325,14 @@ def shape_after_expand(ori_shape):
325325
non_nhwc_trans_inputs.append([input_id, n])
326326

327327
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nhwc_trans_consumers
328+
shape_unknow = [input_id for input_id, _ in non_nhwc_trans_inputs if self._g.get_shape(input_id) is None]
329+
if shape_unknow:
330+
if self._g.opset <= 9:
331+
msg = "%s 's shape is unknown, ConstantOfShape will be used which exists in version 9 or higher" \
332+
"while graph's opset version is %s" % (shape_unknow, self._g.opset)
333+
self.logger.warning(msg)
334+
return False
335+
328336
for input_id, n in non_nhwc_trans_inputs:
329337
shape = self._g.get_shape(input_id)
330338
# if rank of n is not 4, then we need to insert a reshape op before inserting a transpose
@@ -343,13 +351,16 @@ def shape_after_expand(ori_shape):
343351
input_of_new_trans = input_id
344352
else:
345353
shape_4d = shape_after_expand(shape)
354+
if shape_4d is None:
355+
return False
346356
const = self._g.make_const(utils.make_name("reshape_shape"), np.array(shape_4d, np.int64)).output[0]
347357
reshape = self._g.make_node("Reshape", [input_id, const]).output[0]
348358
input_of_new_trans = reshape
349359

350-
nchw_node = self._g.make_node("Transpose", [input_of_new_trans], attr={"perm": [0, 3, 1, 2]})
351-
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
360+
nchw_node = self._g.make_node("Transpose", [input_of_new_trans], attr={"perm": NHWC_TO_NCHW})
361+
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC})
352362
self._g.replace_input(node, input_id, nhwc_node.output[0])
363+
return True
353364

354365
def _add_handler(self, trans, node):
355366
if node.inputs[1].is_const():

0 commit comments

Comments
 (0)