|
12 | 12 | from tf2onnx import utils
|
13 | 13 | from tf2onnx.graph import GraphUtil
|
14 | 14 | from backend_test_base import Tf2OnnxBackendTestBase
|
15 |
| -from common import unittest_main |
| 15 | +from common import unittest_main, group_nodes_by_type |
16 | 16 |
|
17 | 17 |
|
18 | 18 | # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
|
@@ -162,6 +162,33 @@ def test_transpose_with_identity(self):
|
162 | 162 | self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
|
163 | 163 | model_proto, remaining_transpose_num=1)
|
164 | 164 |
|
| 165 | + def test_trans_output_as_graph_outputs(self): |
| 166 | + """ |
| 167 | + If transpose's output is graph's output, don't optimize it. |
| 168 | + """ |
| 169 | + trans = helper.make_node("Transpose", ["X"], ["Y"], name="trans", perm=[0, 2, 3, 1]) |
| 170 | + graph_proto = helper.make_graph( |
| 171 | + [trans], |
| 172 | + "trans-to-graph-output", |
| 173 | + [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))], |
| 174 | + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 5, 3))], |
| 175 | + ) |
| 176 | + |
| 177 | + graph = GraphUtil.create_graph_from_onnx_graph(graph_proto) |
| 178 | + # remove identity to graph output |
| 179 | + identity_op = graph.get_node_by_output(graph.outputs[0]) |
| 180 | + graph.outputs = [identity_op.input[0]] |
| 181 | + graph.remove_node(identity_op.name) |
| 182 | + |
| 183 | + optimized_graph = GraphUtil.optimize_graph(graph, "onnx-tests") |
| 184 | + |
| 185 | + self.assertTrue(optimized_graph, msg="graph after optimizer should not be None") |
| 186 | + |
| 187 | + trans_cnt = len(group_nodes_by_type(optimized_graph)["Transpose"]) |
| 188 | + |
| 189 | + self.assertTrue(trans_cnt == 1, msg="Expect 1 Transpose ops left, but actually " + |
| 190 | + str(trans_cnt) + " left") |
| 191 | + |
165 | 192 | # Tranpose Optimizer Tests End
|
166 | 193 |
|
167 | 194 | # Identity Optimizer Tests Start
|
@@ -396,6 +423,5 @@ def test_duplicated_need_multiple_run(self):
|
396 | 423 | op_type="Log", remaining_op_num=3)
|
397 | 424 | # Merge Duplicated Nodes Optimizer Tests End
|
398 | 425 |
|
399 |
| - |
400 | 426 | if __name__ == "__main__":
|
401 | 427 | unittest_main()
|
0 commit comments