@@ -134,29 +134,60 @@ def test_out(self):
134
134
result1 , = exe .run (feed = {"data1" : input ,
135
135
"data2" : input2 },
136
136
fetch_list = [result_squeeze ])
137
- self .assertTrue (np .allclose (input1 , result1 ))
137
+ self .assertTrue (np .array_equal (input1 , result1 ))
138
+ self .assertEqual (input1 .shape , result1 .shape )
138
139
139
140
140
141
class API_TestDyUnsqueeze (unittest .TestCase ):
141
142
def test_out (self ):
142
143
with fluid .dygraph .guard ():
143
144
input_1 = np .random .random ([5 , 1 , 10 ]).astype ("int32" )
144
- input1 = np .squeeze (input_1 , axis = 1 )
145
+ input1 = np .expand_dims (input_1 , axis = 1 )
145
146
input = fluid .dygraph .to_variable (input_1 )
146
147
output = paddle .unsqueeze (input , axis = [1 ])
147
148
out_np = output .numpy ()
148
- self .assertTrue (np .allclose (input1 , out_np ))
149
+ self .assertTrue (np .array_equal (input1 , out_np ))
150
+ self .assertEqual (input1 .shape , out_np .shape )
149
151
150
152
151
153
class API_TestDyUnsqueeze2 (unittest .TestCase ):
152
154
def test_out (self ):
153
155
with fluid .dygraph .guard ():
154
- input_1 = np .random .random ([5 , 1 , 10 ]).astype ("int32" )
155
- input1 = np .squeeze ( input_1 , axis = 1 )
156
- input = fluid .dygraph .to_variable (input_1 )
156
+ input1 = np .random .random ([5 , 10 ]).astype ("int32" )
157
+ out1 = np .expand_dims ( input1 , axis = 1 )
158
+ input = fluid .dygraph .to_variable (input1 )
157
159
output = paddle .unsqueeze (input , axis = 1 )
158
160
out_np = output .numpy ()
159
- self .assertTrue (np .allclose (input1 , out_np ))
161
+ self .assertTrue (np .array_equal (out1 , out_np ))
162
+ self .assertEqual (out1 .shape , out_np .shape )
163
+
164
+
165
+ class API_TestDyUnsqueezeAxisTensor (unittest .TestCase ):
166
+ def test_out (self ):
167
+ with fluid .dygraph .guard ():
168
+ input1 = np .random .random ([5 , 10 ]).astype ("int32" )
169
+ out1 = np .expand_dims (input1 , axis = 1 )
170
+ input = fluid .dygraph .to_variable (input1 )
171
+ output = paddle .unsqueeze (input , axis = paddle .to_tensor ([1 ]))
172
+ out_np = output .numpy ()
173
+ self .assertTrue (np .array_equal (out1 , out_np ))
174
+ self .assertEqual (out1 .shape , out_np .shape )
175
+
176
+
177
+ class API_TestDyUnsqueezeAxisTensorList (unittest .TestCase ):
178
+ def test_out (self ):
179
+ with fluid .dygraph .guard ():
180
+ input1 = np .random .random ([5 , 10 ]).astype ("int32" )
181
+ # Actually, expand_dims supports tuple since version 1.18.0
182
+ out1 = np .expand_dims (input1 , axis = 1 )
183
+ out1 = np .expand_dims (out1 , axis = 2 )
184
+ input = fluid .dygraph .to_variable (input1 )
185
+ output = paddle .unsqueeze (
186
+ fluid .dygraph .to_variable (input1 ),
187
+ axis = [paddle .to_tensor ([1 ]), paddle .to_tensor ([2 ])])
188
+ out_np = output .numpy ()
189
+ self .assertTrue (np .array_equal (out1 , out_np ))
190
+ self .assertEqual (out1 .shape , out_np .shape )
160
191
161
192
162
193
if __name__ == "__main__" :
0 commit comments