@@ -63,6 +63,33 @@ def check_transpose_perm(self, model_proto, expected_perm):
63
63
perm = list (node .attribute [0 ].ints )
64
64
self .assertEqual (perm , expected_perm )
65
65
66
+ def test_transpose_with_concat (self ):
67
+ input_shape = (2 , 3 , 4 , 5 )
68
+ perm = [0 , 3 , 1 , 2 ]
69
+ input_shape_with_trans = [input_shape [i ] for i in perm ]
70
+ for axis in [0 , 1 , 2 , 3 ]:
71
+ output_before_trans = list (input_shape )
72
+ output_before_trans [axis ] *= 2
73
+ output_shape = [output_before_trans [i ] for i in [0 , 3 , 1 , 2 ]]
74
+ node1 = helper .make_node ("Transpose" , ["input_data1" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans" )
75
+ node2 = helper .make_node ("Concat" , ["Y" , "input_data2" ], ["Z" ], axis = axis , name = "concat" )
76
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans2" )
77
+
78
+ graph = helper .make_graph (
79
+ [node1 , node2 , node3 ],
80
+ "test_transpose_with_concat" ,
81
+ [helper .make_tensor_value_info ("input_data1" , TensorProto .FLOAT , input_shape_with_trans ),
82
+ helper .make_tensor_value_info ("input_data2" , TensorProto .FLOAT , input_shape ),
83
+ ],
84
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , output_shape )],
85
+ )
86
+
87
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
88
+ feed_dict = {"input_data1" : np .random .randn (* input_shape_with_trans ).astype (np .float32 ),
89
+ "input_data2" : np .random .randn (* input_shape ).astype (np .float32 ),
90
+ }
91
+ self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
92
+
66
93
def test_transpose_relu (self ):
67
94
node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
68
95
node2 = helper .make_node ("Relu" , ["Y" ], ["Z" ], name = "relu" )
0 commit comments