@@ -231,56 +231,5 @@ def test_check_grad(self):
231
231
pass
232
232
233
233
234
- class TestMinMaxWithIndexPlace (unittest .TestCase ):
235
- """min/max_with_index has no CPU version, so when CUDA is not available,
236
- we skip all the above test. A runtime error will be emitted if min/max_with_index
237
- is called on CPU, this unit test tries capturing it.
238
- """
239
-
240
- def init (self ):
241
- self .input_shape = [30 , 10 , 10 ]
242
- self .data = np .random .randn (30 , 10 , 10 )
243
-
244
- def setUp (self ):
245
- self .init ()
246
-
247
- def cpu_place (self ):
248
- self .place = core .CPUPlace ()
249
-
250
- def test_api_static_cpu_err_handling_1 (self ):
251
- self .cpu_place ()
252
- with (
253
- self .assertRaises (RuntimeError ),
254
- paddle .static .program_guard (paddle .static .Program ()),
255
- ):
256
- input = paddle .static .data (
257
- name = "input" , shape = self .input_shape , dtype = "float64"
258
- )
259
- output = max_with_index (input , dim = 0 )
260
- exe = paddle .static .Executor (self .place )
261
- result = exe .run (
262
- paddle .static .default_main_program (),
263
- feed = {'input' : self .data },
264
- fetch_list = [output ],
265
- )
266
-
267
- def test_api_static_cpu_err_handling_2 (self ):
268
- self .cpu_place ()
269
- with (
270
- self .assertRaises (RuntimeError ),
271
- paddle .static .program_guard (paddle .static .Program ()),
272
- ):
273
- input = paddle .static .data (
274
- name = "input" , shape = self .input_shape , dtype = "float32"
275
- )
276
- output = min_with_index (input , dim = - 2 , keepdim = True )
277
- exe = paddle .static .Executor (self .place )
278
- result = exe .run (
279
- paddle .static .default_main_program (),
280
- feed = {'input' : self .data .astype (np .float32 )},
281
- fetch_list = [output ],
282
- )
283
-
284
-
285
234
if __name__ == "__main__" :
286
235
unittest .main ()
0 commit comments