Skip to content

Commit 89aca79

Browse files
authored
Merge pull request #543 from zhijxu-MS/enhance_trans_opt
Enhance trans opt
2 parents 9b249c6 + b182821 commit 89aca79

File tree

2 files changed

+63
-26
lines changed

2 files changed

+63
-26
lines changed

tests/test_optimizers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,43 @@ def test_trans_output_as_graph_outputs(self):
188188

189189
self.assertTrue(trans_cnt == 1, msg="Expect 1 Transpose ops left, but actually " + str(trans_cnt) + " left")
190190

191+
def test_trans_can_be_replaced_with_reshape1(self):
    """NHWC->NCHW Transposes that move no data must be optimized to Reshape."""
    # test trans-NHWC
    input_shapes_np = [(2, 3, 4, 1), (2, 1, 1, 4), (2, 3, 4, 1)]
    input_shapes = [(2, 3, 4, 1), (2, 1, 1, 4), (2, -1, -1, 1)]
    perm = (0, 3, 1, 2)
    for shape_np, shape in zip(input_shapes_np, input_shapes):
        permuted_shape = [shape[axis] for axis in perm]
        trans_node = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
        graph_def = helper.make_graph(
            [trans_node],
            "test_trans_can_be_replaced_with_reshape",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, permuted_shape)],
        )
        model_proto = helper.make_model(graph_def, producer_name="onnx-tests")
        feed = {"X": np.random.randn(*shape_np).astype(np.float32)}
        self.run_transpose_compare(["Y"], feed, model_proto, remaining_transpose_num=0)
209+
210+
def test_trans_can_be_replaced_with_reshape2(self):
    """NCHW->NHWC Transposes that move no data must be optimized to Reshape."""
    # test trans-NCHW
    input_shapes_np = [(2, 1, 3, 4), (2, 4, 1, 1), (2, 1, 3, 4)]
    input_shapes = [(2, 1, 3, 4), (2, 4, 1, 1), (2, 1, -1, -1)]
    perm = (0, 2, 3, 1)
    for shape_np, shape in zip(input_shapes_np, input_shapes):
        permuted_shape = [shape[axis] for axis in perm]
        trans_node = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
        graph_def = helper.make_graph(
            [trans_node],
            "test_trans_can_be_replaced_with_reshape",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, permuted_shape)],
        )
        model_proto = helper.make_model(graph_def, producer_name="onnx-tests")
        feed = {"X": np.random.randn(*shape_np).astype(np.float32)}
        self.run_transpose_compare(["Y"], feed, model_proto, remaining_transpose_num=0)
191228
# Transpose Optimizer Tests End
192229

193230
# Identity Optimizer Tests Start

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,40 +74,40 @@ def pre_optimize_action(self):
7474
self._g.topological_sort(self._g.get_nodes())
7575

7676
def post_optimize_action(self):
    """Replace data-preserving Transpose nodes with Reshape nodes.

    When channel == 1 or height == width == 1, an NHWC<->NCHW Transpose does
    not actually reorder any data, but ONNX runtimes still copy the tensor for
    it; a Reshape over the same buffer expresses the identical result more
    cheaply, so such Transposes are rewritten here.
    """

    def _calculate_new_shape(graph, op):
        # Build (and return the output name of) the shape input for the
        # replacement Reshape node.
        input_shape = graph.get_shape(op.input[0])
        if input_shape.count(-1) <= 1:
            # Shape is statically known up to one -1 wildcard, which Reshape
            # accepts directly: emit the permuted shape as an int64 const.
            if is_nchw_transpose(op):
                new_shape = [input_shape[0], input_shape[3], input_shape[1], input_shape[2]]
            else:
                new_shape = [input_shape[0], input_shape[2], input_shape[3], input_shape[1]]
            return graph.make_const(utils.make_name("new_shape"),
                                    np.array(new_shape, dtype=np.int64)).output[0]

        # Reshape requires that the target shape contains at most one -1, so
        # with two or more unknown dims the shape must be computed at runtime:
        # Shape(input) gathered in permuted order.
        input_shape = graph.make_node("Shape", [op.input[0]]).output[0]
        if is_nchw_transpose(op):
            perm = [0, 3, 1, 2]
        else:
            perm = [0, 2, 3, 1]
        # ONNX Gather indices must be int32/int64; numpy's default int type is
        # platform-dependent, so pin int64 explicitly (matches the const above).
        indice = graph.make_const(utils.make_name("indice"),
                                  np.array(perm, dtype=np.int64)).output[0]
        return graph.make_node("Gather", [input_shape, indice]).output[0]

    # Snapshot the node list: the loop body removes/adds nodes in the graph.
    nodes = self.nodes
    # if channel==1 or height==width==1, replace transpose with reshape;
    # replacing trans with reshape is because transpose will copy data even
    # if this transpose does nothing
    for op in nodes:
        if op.type == "Transpose":
            input_shape = self._g.get_shape(op.input[0])
            if not input_shape:
                # Unknown input shape: cannot prove the transpose is a no-op.
                continue
            if (is_nchw_transpose(op) and (input_shape[3] == 1 or (input_shape[1:3] == [1, 1])))\
                    or (is_nhwc_transpose(op) and (input_shape[1] == 1 or (input_shape[2:4] == [1, 1]))):
                new_shape = _calculate_new_shape(self._g, op)
                # Replace transpose with a Reshape that reuses its name and
                # outputs, so downstream consumers are unaffected.
                self._g.remove_node(op.name)
                self._g.make_node("Reshape", [op.input[0], new_shape], name=op.name, outputs=op.output)
                self._g.topological_sort(self._g.get_nodes())
111111

112112
def merge_duplicated_transposes(self):
113113
# strategy used in previous procedure is to move transpose nodes down if possible,

0 commit comments

Comments
 (0)