
Commit fdc7e32

Merge pull request #588 from zhijxu-MS/fix_transpose_add_bug

fix bug of transpose opt with broadcasting op

2 parents 64e4208 + e1f837a commit fdc7e32

File tree

2 files changed: +89 −5 lines
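
The transpose optimizer cancels Transpose → Add → Transpose chains by pushing the leading transpose past the Add. When the Add's second input broadcasts (here a rank-1 tensor), it must first be reshaped to rank 4 and transposed as well; skipping that reshape was the bug this PR fixes. A minimal numpy sketch of the rewrite, using the shapes from test_transpose_with_add1 below (variable names are illustrative, not from tf2onnx):

import numpy as np

x = np.random.randn(2, 3, 4, 5).astype(np.float32)  # first input of the pattern
b = np.random.randn(3).astype(np.float32)           # broadcast input of Add

# original graph: Transpose(0,2,3,1) -> Add -> Transpose(0,3,1,2)
res = (x.transpose(0, 2, 3, 1) + b).transpose(0, 3, 1, 2)

# after the fix: the transposes around Add cancel, while b is reshaped to
# rank 4 and transposed with the same perm, giving shape (1, 3, 1, 1)
b_4d = b.reshape(1, 1, 1, 3).transpose(0, 3, 1, 2)
res_opt = x + b_4d

assert np.allclose(res, res_opt)  # both graphs compute the same result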

tests/test_optimizers.py

42 additions & 0 deletions
@@ -90,6 +90,48 @@ def test_transpose_with_concat(self):
                      }
         self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)

+    def test_transpose_with_add1(self):
+        # when a transpose is followed by a broadcasting op, a reshape is needed
+        # when switching the transpose with this op, since the op must broadcast its inputs
+        node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
+        node2 = helper.make_node("Add", ["Y", "input_data2"], ["Z"], name="add")
+        node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans2")
+
+        graph = helper.make_graph(
+            [node1, node2, node3],
+            "transpose_with_shape",
+            [helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, (2, 3, 4, 5)),
+             helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, (3,)),
+             ],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, (2, 3, 4, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        feed_dict = {"input_data1": np.random.randn(2, 3, 4, 5).astype(np.float32),
+                     "input_data2": np.random.randn(3).astype(np.float32),
+                     }
+        self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=0)
+
+    def test_transpose_with_add2(self):
+        node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
+        node2 = helper.make_node("Add", ["Y", "input_data2"], ["Z"], name="add")
+        node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans2")
+
+        graph = helper.make_graph(
+            [node1, node2, node3],
+            "transpose_with_shape",
+            [helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, (2, 3, 4, 5)),
+             helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, (2, 4, 5, 3)),
+             ],
+            [helper.make_tensor_value_info("res", TensorProto.FLOAT, (2, 3, 4, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        feed_dict = {"input_data1": np.random.randn(2, 3, 4, 5).astype(np.float32),
+                     "input_data2": np.random.randn(2, 4, 5, 3).astype(np.float32),
+                     }
+        self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
+
     def test_transpose_relu(self):
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
         node2 = helper.make_node("Relu", ["Y"], ["Z"], name="relu")

tf2onnx/optimizer/transpose_optimizer.py

47 additions & 5 deletions
@@ -7,7 +7,7 @@
 from collections import defaultdict

 import numpy as np
-
+import onnx
 from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW
 from .. import utils
 from .optimizer_base import GraphOptimizerBase
@@ -191,8 +191,8 @@ def _initialize_handlers(self):

     def _handle_node_having_branches(self, node):
         # create transpose pairs for the inputs that are not transposes
-        self._create_transpose_pairs_before_node(node)
-
+        if not self._create_transpose_pairs_before_node(node):
+            return False
         # make sure all of the node's input transposes have only 1 consumer node,
         # otherwise, it would impact their other output nodes
         if self._nodes_has_single_consumer_node(node.inputs):
@@ -307,6 +307,16 @@ def _create_transpose_pairs_after_node(self, node):
             self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])

     def _create_transpose_pairs_before_node(self, node):
+        def shape_after_expand(ori_shape):
+            # expand the shape to 4D according to the broadcasting rule, without tiling
+            # the tensor here; we still count on the broadcasting op to tile the tensor
+            if ori_shape.count(-1) >= 2:
+                self.logger.warning("%s shape can contain at most one -1, otherwise the Reshape op can't work", node.name)
+                return None
+            ori_rank = len(ori_shape)
+            new_shape = [1] * (4 - ori_rank) + ori_shape
+            return new_shape
+
         non_nhwc_trans_inputs = []
         for input_id, n in zip(node.input, node.inputs):
             if not is_nhwc_transpose(n):
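
A standalone sketch of the shape_after_expand helper added above: left-padding a shape with 1s up to rank 4 matches how ONNX and numpy broadcasting align lower-rank operands, so the Reshape built from it cannot change the broadcasting op's result. Restated here for illustration, not copied verbatim from the commit:

def shape_after_expand(ori_shape):
    # at most one -1 (unknown dimension) is allowed, otherwise Reshape is ambiguous
    if ori_shape.count(-1) >= 2:
        return None
    return [1] * (4 - len(ori_shape)) + ori_shape

assert shape_after_expand([3]) == [1, 1, 1, 3]
assert shape_after_expand([4, 5, 3]) == [1, 4, 5, 3]
assert shape_after_expand([-1, -1, 3]) is None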
@@ -315,10 +325,42 @@ def _create_transpose_pairs_before_node(self, node):
             non_nhwc_trans_inputs.append([input_id, n])

         # add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each of the non_nhwc_trans_inputs
+        shape_unknown = [input_id for input_id, _ in non_nhwc_trans_inputs if self._g.get_shape(input_id) is None]
+        if shape_unknown:
+            if self._g.opset <= 9:
+                msg = "%s's shape is unknown, ConstantOfShape will be used which exists in version 9 or higher " \
+                      "while graph's opset version is %s" % (shape_unknown, self._g.opset)
+                self.logger.warning(msg)
+                return False
+
         for input_id, n in non_nhwc_trans_inputs:
-            nchw_node = self._g.make_node("Transpose", [input_id], attr={"perm": [0, 3, 1, 2]})
-            nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
+            shape = self._g.get_shape(input_id)
+            # if the rank of n is not 4, we need to insert a reshape op before inserting a transpose;
+            # for example, if the shape of n is [x, y], the output shape of the reshape will be [1, 1, x, y]
+            if shape is None:
+                const_4 = self._g.make_const(utils.make_name("const_4"), np.array([4], np.int64)).output[0]
+                tensor_1 = onnx.helper.make_tensor("value", onnx.TensorProto.INT64, [1], [1])
+                shape_node = self._g.make_node("Shape", [input_id]).output[0]
+                rank_node = self._g.make_node("Shape", [shape_node]).output[0]
+                expand_rank = self._g.make_node("Sub", [const_4, rank_node]).output[0]
+                array_fill_1 = self._g.make_node("ConstantOfShape", [expand_rank], attr={"value": tensor_1}).output[0]
+                new_shape = self._g.make_node("Concat", [array_fill_1, shape_node], attr={"axis": 0}).output[0]
+                reshape = self._g.make_node("Reshape", [input_id, new_shape]).output[0]
+                input_of_new_trans = reshape
+            elif len(shape) == 4:
+                input_of_new_trans = input_id
+            else:
+                shape_4d = shape_after_expand(shape)
+                if shape_4d is None:
+                    return False
+                const = self._g.make_const(utils.make_name("reshape_shape"), np.array(shape_4d, np.int64)).output[0]
+                reshape = self._g.make_node("Reshape", [input_id, const]).output[0]
+                input_of_new_trans = reshape
+
+            nchw_node = self._g.make_node("Transpose", [input_of_new_trans], attr={"perm": NHWC_TO_NCHW})
+            nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": NCHW_TO_NHWC})
             self._g.replace_input(node, input_id, nhwc_node.output[0])
+        return True

     def _add_handler(self, trans, node):
         if node.inputs[1].is_const():
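
When the input's shape is unknown at optimization time, the new code assembles the padded shape at runtime from ONNX ops: Shape gives the shape, a second Shape gives the rank, Sub computes how many leading 1s are missing, ConstantOfShape materializes them, and Concat plus Reshape apply the result. A rough numpy emulation of that subgraph, assuming the runtime rank is at most 4 (expand_to_4d is a hypothetical name, not from the source):

import numpy as np

def expand_to_4d(t):
    shape = np.array(t.shape, dtype=np.int64)      # Shape(input)
    rank = np.array(shape.shape, dtype=np.int64)   # Shape(Shape(input)), i.e. the rank as a 1-D tensor
    pad = np.ones(4 - rank[0], dtype=np.int64)     # ConstantOfShape(Sub(const_4, rank)), filled with 1s
    new_shape = np.concatenate([pad, shape])       # Concat(axis=0)
    return t.reshape(new_shape)                    # Reshape

assert expand_to_4d(np.zeros((3,))).shape == (1, 1, 1, 3)
assert expand_to_4d(np.zeros((4, 5, 3))).shape == (1, 4, 5, 3)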
