Skip to content

Commit 073394e

Browse files
committed
enhance transpose optimizer and add related tests
1 parent 1fcbc53 commit 073394e

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

tests/test_optimizers.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
class OptimizerTests(Tf2OnnxBackendTestBase):
2020
"""Run original model proto and modified model proto with onnxruntime, compare the results."""
2121

22-
def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto, debug=False, rtol=1e-07):
22+
def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
23+
remaining_transpose_num=None, debug=False, rtol=1e-07):
2324
origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")
2425

2526
new_proto = GraphUtil.opt_transposes_with_model_proto(origin_proto)
@@ -32,6 +33,8 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
3233
current = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)
3334

3435
self.assertTrue(current["Transpose"] < previous["Transpose"], msg="transpose ops count not changed")
36+
if remaining_transpose_num is not None:
37+
self.assertTrue(current["Transpose"] == remaining_transpose_num, msg="some transpose ops left unexpected")
3538

3639
if self.config.is_onnxruntime_backend:
3740
expected = self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port)
@@ -58,7 +61,7 @@ def test_relu(self):
5861

5962
model_proto = helper.make_model(graph, producer_name="onnx-tests")
6063
self.run_and_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
61-
model_proto)
64+
model_proto, remaining_transpose_num=0)
6265

6366
def test_leaky_relu(self):
6467
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
@@ -74,7 +77,7 @@ def test_leaky_relu(self):
7477

7578
model_proto = helper.make_model(graph, producer_name="onnx-tests")
7679
self.run_and_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
77-
model_proto)
80+
model_proto, remaining_transpose_num=0)
7881

7982
def test_max(self):
8083
const_1_val = [2.0]
@@ -102,7 +105,38 @@ def test_max(self):
102105

103106
model_proto = helper.make_model(graph, producer_name="onnx-tests")
104107
self.run_and_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
105-
model_proto)
108+
model_proto, remaining_transpose_num=0)
109+
110+
def test_transpose_merge(self):
    """Two identical Transposes of the same input feeding a Mul should be
    merged by the optimizer, leaving exactly one Transpose behind."""
    trans_a = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
    trans_b = helper.make_node("Transpose", ["X"], ["Y_1"], perm=[0, 2, 3, 1], name="trans_1")
    mul = helper.make_node("Mul", ["Y", "Y_1"], ["OUT"], name="mul")

    graph_def = helper.make_graph(
        [trans_a, trans_b, mul],
        "transpose-merge-test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
        [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (2, 4, 5, 3))],
    )

    model_proto = helper.make_model(graph_def, producer_name="onnx-tests")
    feed = {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)}
    # One transpose must survive: the merged pair still permutes the data once.
    self.run_and_compare(["OUT"], feed, model_proto, remaining_transpose_num=1)
125+
126+
def test_transpose_with_shape(self):
    """A Transpose consumed only by Shape should be optimized away entirely
    (Shape needs just the dims, not the permuted data)."""
    transpose = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
    shape = helper.make_node("Shape", ["Y"], ["Z"], name="shape")

    graph_def = helper.make_graph(
        [transpose, shape],
        "transpose_with_shape",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
        [helper.make_tensor_value_info("Z", TensorProto.INT64, [4])],
    )

    model_proto = helper.make_model(graph_def, producer_name="onnx-tests")
    feed = {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)}
    # No transpose should remain after the rewrite.
    self.run_and_compare(["Z"], feed, model_proto, remaining_transpose_num=0)
106140

107141

108142
if __name__ == "__main__":

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def _initialize_handlers(self):
187187
"Pad": self._pad_handler,
188188
"ReduceMean": self._reducemean_handler,
189189
"Relu": self._simple_through_handler,
190+
"Shape": self._shape_handler,
190191
"Slice": self._slice_handler,
191192
"Split": self._split_handler,
192193
"Tanh": self._simple_through_handler,
@@ -460,3 +461,19 @@ def _slice_handler(self, trans, node):
460461

461462
def _simple_through_handler(self, trans, node):
462463
return self._switch_transpose_and_node(node, trans)
464+
465+
def _shape_handler(self, trans, node):
    """Rewrite ``input -> Transpose -> Shape`` into ``input -> Shape -> Gather``.

    Shape only reads the dimensions, so the (potentially expensive) data
    transpose can be dropped; the permutation is instead applied to the
    shape vector by gathering it with the Transpose's ``perm`` indices.

    Args:
        trans: the Transpose node to eliminate.
        node: the Shape node consuming the Transpose's output.

    Returns:
        True when the rewrite was applied, False when the graph was left
        untouched (transpose has other consumers).
    """
    # Only safe when the Shape node is the transpose's sole consumer;
    # otherwise removing the transpose would break the other consumers.
    if not self._transpose_has_single_consumer_node([trans]):
        return False

    # Preserve the original output metadata before the nodes disappear.
    output_shape = self._g.get_shape(node.output[0])
    output_dtype = self._g.get_dtype(node.output[0])
    self._g.remove_node(trans.name)
    self._g.remove_node(node.name)
    shape_node = self._g.make_node("Shape", [trans.input[0]])
    # Fix 1: the original hard-coded constant name "Const" collides when this
    # handler fires more than once in a graph; derive a unique name from the
    # (just removed, hence free) Shape node's name instead.
    # Fix 2: ONNX Gather requires int32/int64 indices; np.array on a Python
    # int list is platform-dependent (int32 on Windows), so pin int64.
    perm_const = self._g.make_const(
        node.name + "_perm",
        np.array(trans.get_attr("perm").ints, dtype=np.int64))
    gather_node = self._g.make_node(
        "Gather", [shape_node.output[0], perm_const.output[0]], outputs=node.output)
    self._g.set_shape(gather_node.output[0], output_shape)
    self._g.set_dtype(gather_node.output[0], output_dtype)
    return True

0 commit comments

Comments
 (0)