Skip to content

Commit 2dcfc75

Browse files
committed
fix bug of transpose opt with broadcasting op
1 parent fe0614f commit 2dcfc75

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

tests/test_optimizers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,27 @@ def test_transpose_with_concat(self):
9090
}
9191
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
9292

93+
def test_transpose_with_add(self):
    """Check the transpose optimizer when a Transpose feeds a broadcasting Add.

    The transposed 4D tensor ``Y`` is added to ``input_data2``, whose declared
    shape is ``(3,)``, so the Add has to broadcast its second input. Swapping
    the Transpose with such an op requires reshaping the lower-rank input.
    One Transpose is expected to remain after optimization.
    """
    transpose_node = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
    add_node = helper.make_node("Add", ["Y", "input_data2"], ["Z"], name="add")

    # graph inputs: the 4D tensor to transpose and the 1D operand to broadcast
    graph_inputs = [
        helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, (2, 3, 4, 5)),
        helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, (3,)),
    ]
    graph_outputs = [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [2, 4, 5, 3])]
    graph = helper.make_graph(
        [transpose_node, add_node],
        "transpose_with_shape",
        graph_inputs,
        graph_outputs,
    )

    model_proto = helper.make_model(graph, producer_name="onnx-tests")
    feed_dict = {
        "input_data1": np.random.randn(2, 3, 4, 5).astype(np.float32),
        "input_data2": np.random.randn(3).astype(np.float32),
    }
    self.run_transpose_compare(["Z"], feed_dict, model_proto, remaining_transpose_num=1)
113+
93114
def test_transpose_relu(self):
94115
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
95116
node2 = helper.make_node("Relu", ["Y"], ["Z"], name="relu")

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ def _create_transpose_pairs_after_node(self, node):
307307
self._g.replace_input(consumer, node.output[0], nhwc_node.output[0])
308308

309309
def _create_transpose_pairs_before_node(self, node):
310+
def shape_after_expand(ori_shape):
    """Left-pad ``ori_shape`` with 1s so that it describes a 4D tensor.

    Leading 1s are prepended, following the broadcasting rule, so that a
    lower-rank input can be reshaped to 4D.

    Args:
        ori_shape: sequence of dimension sizes (list, tuple or array-like).
            It is checked via ``utils.make_sure`` to hold at most one ``-1``
            and at most four entries.

    Returns:
        A new list: ``4 - len(ori_shape)`` leading 1s followed by the
        original dimensions.
    """
    # Normalise to a list first: ``list + tuple`` raises TypeError and
    # ``list + ndarray`` performs an elementwise add instead of concatenating.
    dims = list(ori_shape)
    # the result is used as a Reshape target, which allows a single -1 only
    utils.make_sure(dims.count(-1) <= 1, "shape can contain one -1 at most")
    ori_rank = len(dims)
    # NOTE(review): the message blames ONNX, but the 4D limit looks like it
    # comes from the NHWC/NCHW perms used by this optimizer - confirm.
    utils.make_sure(ori_rank <= 4, "ONNX only supports 4D data")
    return [1] * (4 - ori_rank) + dims
317+
310318
non_nhwc_trans_inputs = []
311319
for input_id, n in zip(node.input, node.inputs):
312320
if not is_nhwc_transpose(n):
@@ -316,7 +324,16 @@ def _create_transpose_pairs_before_node(self, node):
316324

317325
# add Transpose(0, 3, 1, 2) and Transpose(0, 2, 3, 1) before each non_nhwc_trans_consumers
318326
for input_id, n in non_nhwc_trans_inputs:
319-
nchw_node = self._g.make_node("Transpose", [input_id], attr={"perm": [0, 3, 1, 2]})
327+
shape = self._g.get_shape(n.output[0])
328+
if len(shape) == 4:
329+
nchw_node = self._g.make_node("Transpose", [input_id], attr={"perm": [0, 3, 1, 2]})
330+
else:
331+
shape_4d = shape_after_expand(shape)
332+
shape_const = self._g.make_const(utils.make_name("reshape_shape"),
333+
np_val=np.array(shape_4d, np.int64)).output[0]
334+
reshape = self._g.make_node("Reshape", [input_id, shape_const]).output[0]
335+
nchw_node = self._g.make_node("Transpose", [reshape], attr={"perm": [0, 3, 1, 2]})
336+
320337
nhwc_node = self._g.make_node("Transpose", [nchw_node.output[0]], attr={"perm": [0, 2, 3, 1]})
321338
self._g.replace_input(node, input_id, nhwc_node.output[0])
322339

0 commit comments

Comments
 (0)