Skip to content

Commit 30fa0a8

Browse files
committed
fix bug in trans opt with concat
1 parent d657622 commit 30fa0a8

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

tests/test_optimizers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,33 @@ def check_transpose_perm(self, model_proto, expected_perm):
6363
perm = list(node.attribute[0].ints)
6464
self.assertEqual(perm, expected_perm)
6565

66+
def test_transpose_with_concat(self):
67+
input_shape = (2, 3, 4, 5)
68+
perm = [0, 3, 1, 2]
69+
input_shape_with_trans = [input_shape[i] for i in perm]
70+
for axis in [0, 1, 2, 3]:
71+
output_before_trans = list(input_shape)
72+
output_before_trans[axis] *= 2
73+
output_shape = [output_before_trans[i] for i in [0, 3, 1, 2]]
74+
node1 = helper.make_node("Transpose", ["input_data1"], ["Y"], perm=[0, 2, 3, 1], name="trans")
75+
node2 = helper.make_node("Concat", ["Y", "input_data2"], ["Z"], axis=axis, name="concat")
76+
node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans2")
77+
78+
graph = helper.make_graph(
79+
[node1, node2, node3],
80+
"test_transpose_with_concat",
81+
[helper.make_tensor_value_info("input_data1", TensorProto.FLOAT, input_shape_with_trans),
82+
helper.make_tensor_value_info("input_data2", TensorProto.FLOAT, input_shape),
83+
],
84+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)],
85+
)
86+
87+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
88+
feed_dict = {"input_data1": np.random.randn(*input_shape_with_trans).astype(np.float32),
89+
"input_data2": np.random.randn(*input_shape).astype(np.float32),
90+
}
91+
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)
92+
6693
def test_transpose_relu(self):
6794
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
6895
node2 = helper.make_node("Relu", ["Y"], ["Z"], name="relu")

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,14 @@ def _identity_handler(self, trans, node):
426426

427427
def _concat_handler(self, trans, node):
428428
if self._handle_node_having_branches(node):
429-
node.set_attr("axis", 1)
429+
perm = trans.get_attr("perm").ints
430+
axis_attr = node.get_attr("axis")
431+
if axis_attr:
432+
axis = axis_attr.i # fix
433+
else:
434+
axis = 0
435+
new_axis = perm[axis]
436+
node.set_attr("axis", new_axis)
430437
return True
431438
return False
432439

0 commit comments

Comments
 (0)