@@ -188,6 +188,43 @@ def test_trans_output_as_graph_outputs(self):
188
188
189
189
self .assertTrue (trans_cnt == 1 , msg = "Expect 1 Transpose ops left, but actually " + str (trans_cnt ) + " left" )
190
190
191
+ def test_trans_can_be_replaced_with_reshape1 (self ):
192
+ # test trans-NHWC
193
+ input_shapes_np = [(2 , 3 , 4 , 1 ), (2 , 1 , 1 , 4 ), (2 , 3 , 4 , 1 )]
194
+ input_shapes = [(2 , 3 , 4 , 1 ), (2 , 1 , 1 , 4 ), (2 , - 1 , - 1 , 1 )]
195
+ perm = (0 , 3 , 1 , 2 )
196
+ for input_shape_np , input_shape in zip (input_shapes_np , input_shapes ):
197
+ result_shape = [input_shape [i ] for i in perm ]
198
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm , name = "trans" )
199
+ graph = helper .make_graph (
200
+ [node1 ],
201
+ "test_trans_can_be_replaced_with_reshape" ,
202
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
203
+ [helper .make_tensor_value_info ("Y" , TensorProto .FLOAT , result_shape )],
204
+ )
205
+
206
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
207
+ self .run_transpose_compare (["Y" ], {"X" : np .random .randn (* input_shape_np ).astype (np .float32 )},
208
+ model_proto , remaining_transpose_num = 0 )
209
+
210
+ def test_trans_can_be_replaced_with_reshape2 (self ):
211
+ # test trans-NCHW
212
+ input_shapes_np = [(2 , 1 , 3 , 4 ), (2 , 4 , 1 , 1 ), (2 , 1 , 3 , 4 )]
213
+ input_shapes = [(2 , 1 , 3 , 4 ), (2 , 4 , 1 , 1 ), (2 , 1 , - 1 , - 1 )]
214
+ perm = (0 , 2 , 3 , 1 )
215
+ for input_shape_np , input_shape in zip (input_shapes_np , input_shapes ):
216
+ result_shape = [input_shape [i ] for i in perm ]
217
+ node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm , name = "trans" )
218
+ graph = helper .make_graph (
219
+ [node1 ],
220
+ "test_trans_can_be_replaced_with_reshape" ,
221
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , input_shape )],
222
+ [helper .make_tensor_value_info ("Y" , TensorProto .FLOAT , result_shape )],
223
+ )
224
+
225
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
226
+ self .run_transpose_compare (["Y" ], {"X" : np .random .randn (* input_shape_np ).astype (np .float32 )},
227
+ model_proto , remaining_transpose_num = 0 )
191
228
# Tranpose Optimizer Tests End
192
229
193
230
# Identity Optimizer Tests Start
0 commit comments