@@ -90,26 +90,47 @@ def test_transpose_with_concat(self):
90
90
}
91
91
self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
92
92
93
- def test_transpose_with_add (self ):
93
+ def test_transpose_with_add1 (self ):
94
94
# when transpose follows with a broadcasting op
95
95
# reshape is needed when switching transpose with this op and op need broadcast its inputs
96
96
node1 = helper .make_node ("Transpose" , ["input_data1" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans" )
97
97
node2 = helper .make_node ("Add" , ["Y" , "input_data2" ], ["Z" ], name = "add" )
98
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans2" )
98
99
99
100
graph = helper .make_graph (
100
- [node1 , node2 ],
101
+ [node1 , node2 , node3 ],
101
102
"transpose_with_shape" ,
102
103
[helper .make_tensor_value_info ("input_data1" , TensorProto .FLOAT , (2 , 3 , 4 , 5 )),
103
104
helper .make_tensor_value_info ("input_data2" , TensorProto .FLOAT , (3 ,)),
104
105
],
105
- [helper .make_tensor_value_info ("Z " , TensorProto .FLOAT , [ 2 , 4 , 5 , 3 ] )],
106
+ [helper .make_tensor_value_info ("res " , TensorProto .FLOAT , ( 2 , 3 , 4 , 5 ) )],
106
107
)
107
108
108
109
model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
109
110
feed_dict = {"input_data1" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 ),
110
111
"input_data2" : np .random .randn (3 ).astype (np .float32 ),
111
112
}
112
- self .run_transpose_compare (["Z" ], feed_dict , model_proto , remaining_transpose_num = 1 )
113
+ self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 0 )
114
+
115
+ def test_transpose_with_add2 (self ):
116
+ node1 = helper .make_node ("Transpose" , ["input_data1" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans" )
117
+ node2 = helper .make_node ("Add" , ["Y" , "input_data2" ], ["Z" ], name = "add" )
118
+ node3 = helper .make_node ("Transpose" , ["Z" ], ["res" ], perm = [0 , 3 , 1 , 2 ], name = "trans2" )
119
+
120
+ graph = helper .make_graph (
121
+ [node1 , node2 , node3 ],
122
+ "transpose_with_shape" ,
123
+ [helper .make_tensor_value_info ("input_data1" , TensorProto .FLOAT , (2 , 3 , 4 , 5 )),
124
+ helper .make_tensor_value_info ("input_data2" , TensorProto .FLOAT , (2 , 4 , 5 , 3 )),
125
+ ],
126
+ [helper .make_tensor_value_info ("res" , TensorProto .FLOAT , (2 , 3 , 4 , 5 ))],
127
+ )
128
+
129
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" )
130
+ feed_dict = {"input_data1" : np .random .randn (2 , 3 , 4 , 5 ).astype (np .float32 ),
131
+ "input_data2" : np .random .randn (2 , 4 , 5 , 3 ).astype (np .float32 ),
132
+ }
133
+ self .run_transpose_compare (["res" ], feed_dict , model_proto , remaining_transpose_num = 1 )
113
134
114
135
def test_transpose_relu (self ):
115
136
node1 = helper .make_node ("Transpose" , ["X" ], ["Y" ], perm = [0 , 2 , 3 , 1 ], name = "trans_1" )
0 commit comments