@@ -687,6 +687,28 @@ def test_trans_can_be_replaced_with_reshape2(self):
687
687
self .run_transpose_compare (["Y" ], {"X" : np .random .randn (* input_shape_np ).astype (np .float32 )},
688
688
model_proto , remaining_transpose_num = 0 )
689
689
690
+ def test_two_transposes_switch_with_mul (self ):
691
+ const_node = self ._make_onnx_const (np .array (10 , dtype = np .float32 ), "const_10" )
692
+ node0 = helper .make_node ("Transpose" , ["u1" ], ["v1" ], perm = [0 , 2 , 3 , 1 ], name = "trans_0" )
693
+ node1 = helper .make_node ("Transpose" , ["u2" ], ["v2" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
694
+
695
+ node2 = helper .make_node ("Mul" , ["v1" , "v2" ], ["x" ], name = "mul_1" )
696
+ node3 = helper .make_node ("Mul" , ["x" , const_node .output [0 ]], ["y" ], name = "mul_2" )
697
+ node4 = helper .make_node ("Transpose" , ["y" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans_3" )
698
+
699
+ graph = helper .make_graph (
700
+ [const_node , node0 , node1 , node2 , node3 , node4 ],
701
+ "test-transpose-mul" ,
702
+ [helper .make_tensor_value_info ("u1" , TensorProto .FLOAT , (1 , 6 , 8 , 9 )),
703
+ helper .make_tensor_value_info ("u2" , TensorProto .FLOAT , (1 , 6 , 8 , 9 ))],
704
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (1 , 6 , 8 , 9 ))],
705
+ )
706
+
707
+ model_proto = self .make_model (graph , producer_name = "onnx-tests" )
708
+ self .run_transpose_compare (["res" ], {"u1" : np .random .randn (1 , 6 , 8 , 9 ).astype (np .float32 ),
709
+ "u2" : np .random .randn (1 , 6 , 8 , 9 ).astype (np .float32 )},
710
+ model_proto , remaining_transpose_num = 0 )
711
+
690
712
# Tranpose Optimizer Tests End
691
713
692
714
# Identity Optimizer Tests Start
0 commit comments