@@ -477,6 +477,28 @@ def test_transpose_merge(self, input_shape1, input_shape2, perm):
477
477
self .run_transpose_compare (["OUT" ], {"X" : np .random .randn (* input_shape1 ).astype (np .float32 )},
478
478
model_proto , remaining_transpose_num = 1 )
479
479
480
+
481
+ @parameterized .expand ([
482
+ ((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ], [0 , 3 , 1 , 2 ]),
483
+ ((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ], [0 , 4 , 1 , 2 , 3 ]),
484
+ ])
485
+ def test_transpose_mul_as_square (self , shape , perm_input , perm_output ):
486
+ node0 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = perm_input , name = "trans" )
487
+ node1 = helper .make_node ("Mul" , ["Y" , "Y" ], ["Z" ], name = "mul" )
488
+ node2 = helper .make_node ("Transpose" , ["Z" ], ["OUT" ], perm = perm_output , name = "trans_1" )
489
+
490
+ graph = helper .make_graph (
491
+ [node0 , node1 , node2 ],
492
+ "transpose-mul-as-sqr-test" ,
493
+ [helper .make_tensor_value_info ("X" , TensorProto .FLOAT , shape )],
494
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , shape )],
495
+ )
496
+
497
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
498
+ self .run_transpose_compare (["OUT" ], {"X" : np .random .randn (* shape ).astype (np .float32 )},
499
+ model_proto , remaining_transpose_num = 0 )
500
+
501
+
480
502
@parameterized .expand ([
481
503
((2 , 3 , 4 , 5 ), [0 , 2 , 3 , 1 ]),
482
504
((2 , 3 , 4 , 5 , 6 ), [0 , 2 , 3 , 4 , 1 ]),
0 commit comments