19
19
class OptimizerTests (Tf2OnnxBackendTestBase ):
20
20
"""Run original model proto and modified model proto with onnxruntime, compare the results."""
21
21
22
- def run_and_compare (self , output_names_with_port , onnx_feed_dict , origin_proto , debug = False , rtol = 1e-07 ):
22
+ def run_and_compare (self , output_names_with_port , onnx_feed_dict , origin_proto ,
23
+ remaining_transpose_num = None , debug = False , rtol = 1e-07 ):
23
24
origin_model_path = self .save_onnx_model (origin_proto , onnx_feed_dict , postfix = "_origin" )
24
25
25
26
new_proto = GraphUtil .opt_transposes_with_model_proto (origin_proto )
@@ -32,6 +33,8 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
32
33
current = GraphUtil .get_node_count_from_onnx_graph (new_proto .graph )
33
34
34
35
self .assertTrue (current ["Transpose" ] < previous ["Transpose" ], msg = "transpose ops count not changed" )
36
+ if remaining_transpose_num is not None :
37
+ self .assertTrue (current ["Transpose" ] == remaining_transpose_num , msg = "some transpose ops left unexpected" )
35
38
36
39
if self .config .is_onnxruntime_backend :
37
40
expected = self .run_onnxruntime (origin_model_path , onnx_feed_dict , output_names_with_port )
@@ -58,7 +61,7 @@ def test_relu(self):
58
61
59
62
model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
60
63
self .run_and_compare (["Z1" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
61
- model_proto )
64
+ model_proto , remaining_transpose_num = 0 )
62
65
63
66
def test_leaky_relu (self ):
64
67
node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
@@ -74,7 +77,7 @@ def test_leaky_relu(self):
74
77
75
78
model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
76
79
self .run_and_compare (["Z1" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
77
- model_proto )
80
+ model_proto , remaining_transpose_num = 0 )
78
81
79
82
def test_max (self ):
80
83
const_1_val = [2.0 ]
@@ -102,7 +105,38 @@ def test_max(self):
102
105
103
106
model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
104
107
self .run_and_compare (["Z1" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
105
- model_proto )
108
+ model_proto , remaining_transpose_num = 0 )
109
+
110
+ def test_transpose_merge (self ):
111
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans" )
112
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y_1" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
113
+ node2 = helper .make_node ("Mul" , ["Y" , "Y_1" ], ["OUT" ], name = "mul" )
114
+
115
+ graph = helper .make_graph (
116
+ [node0 , node1 , node2 ],
117
+ "transpose-merge-test" ,
118
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (2 , 3 , 4 , 5 ))],
119
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (2 , 4 , 5 , 3 ))],
120
+ )
121
+
122
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
123
+ self .run_and_compare (["OUT" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
124
+ model_proto , remaining_transpose_num = 1 )
125
+
126
+ def test_transpose_with_shape (self ):
127
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans" )
128
+ node2 = helper .make_node ("Shape" , ["Y" ], ["Z" ], name = "shape" )
129
+
130
+ graph = helper .make_graph (
131
+ [node1 , node2 ],
132
+ "transpose_with_shape" ,
133
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , (2 , 3 , 4 , 5 ))],
134
+ [helper .make_tensor_value_info ("Z" , TensorProto .INT64 , [4 ])],
135
+ )
136
+
137
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
138
+ self .run_and_compare (["Z" ], {"X" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 )},
139
+ model_proto , remaining_transpose_num = 0 )
106
140
107
141
108
142
if __name__ == "__main__" :
0 commit comments