@@ -211,6 +211,31 @@ def test_transpose_leaky_relu(self, shape, perm_input, perm_output):
211
211
self .run_transpose_compare (["Z1" ], {"X" : np .random .randn (* shape ).astype (np .float32 )},
212
212
model_proto , remaining_transpose_num = 0 )
213
213
214
+ @parameterized .expand ([
215
+ ((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
216
+ ((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
217
+ ((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
218
+ ])
219
+ def test_transpose_with_prelu (self , input_shape , perm_input , perm_output ):
220
+ node1 = helper .make_node ("Transpose" , ["input_data1" ], ["Y" ], perm = perm_input , name = "trans" )
221
+ node2 = helper .make_node ("PRelu" , ["Y" , "input_data2" ], ["Z" ], name = "add" )
222
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = perm_output , name = "trans2" )
223
+
224
+ graph = helper .make_graph (
225
+ [node1 , node2 , node3 ],
226
+ "transpose_with_shape" ,
227
+ [helper .make_tensor_value_info ("input_data1" , TensorProto .FLOAT , input_shape ),
228
+ helper .make_tensor_value_info ("input_data2" , TensorProto .FLOAT , (input_shape [1 ],)),
229
+ ],
230
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , input_shape )],
231
+ )
232
+
233
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
234
+ feed_dict = {"input_data1" : np .random .randn (* input_shape ).astype (np .float32 ),
235
+ "input_data2" : np .random .randn (input_shape [1 ]).astype (np .float32 ),
236
+ }
237
+ self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 0 )
238
+
214
239
@parameterized .expand ([
215
240
((2 , 3 , 4 ), [2 , 0 , 1 ], [1 , 2 , 0 ]),
216
241
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
0 commit comments