33
33
# TestFusedElementwiseActivationOp_channelwise_add
34
34
35
35
36
- def create_test_class (test_case , callback , attrs ):
36
+ def create_test_class (test_case ,
37
+ callback ,
38
+ attrs ,
39
+ dtype = np .float32 ,
40
+ grad_chek = True ):
37
41
class TestFusedElementwiseActivationOp_base (OpTest ):
38
42
def setUp (self ):
39
43
self .op_type = "fused_elemwise_activation"
40
- self .dtype = np . float32
44
+ self .dtype = dtype
41
45
self .axis = - 1
42
46
43
47
self .init_input ()
44
48
self .init_output ()
45
49
self .init_attr ()
46
50
51
+ self .out = self .out .astype (self .dtype )
52
+ self .intermediate_out = self .intermediate_out .astype (self .dtype )
53
+
47
54
self .inputs = {
48
55
'X' : OpTest .np_dtype_to_fluid_dtype (self .x ),
49
56
'Y' : OpTest .np_dtype_to_fluid_dtype (self .y )
@@ -71,16 +78,25 @@ def init_attr(self):
71
78
self .attrs [key ] = attrs [key ]
72
79
73
80
def test_check_output (self ):
74
- self .check_output ()
81
+ if self .dtype == np .float16 and core .is_compiled_with_cuda ():
82
+ place = core .CUDAPlace (0 )
83
+ if core .is_float16_supported (place ):
84
+ self .check_output_with_place (place , atol = 1e-3 )
85
+ else :
86
+ self .check_output ()
75
87
76
88
# FIXME(zcd): the intermediate_out_grad is not checked.
77
89
def test_check_grad_normal (self ):
90
+ if not grad_chek :
91
+ return
78
92
if self .attrs ["save_intermediate_out" ]:
79
93
self .check_grad (['X' , 'Y' ], ['Out' ], max_relative_error = 0.005 )
80
94
else :
81
95
self .check_grad (['X' , 'Y' ], ['Out' ], max_relative_error = 0.005 )
82
96
83
97
def test_check_grad_ingore_x (self ):
98
+ if not grad_chek :
99
+ return
84
100
if self .attrs ["save_intermediate_out" ]:
85
101
self .check_grad (
86
102
['Y' ], ['Out' ],
@@ -93,6 +109,8 @@ def test_check_grad_ingore_x(self):
93
109
no_grad_set = set ("X" ))
94
110
95
111
def test_check_grad_ingore_y (self ):
112
+ if not grad_chek :
113
+ return
96
114
if self .attrs ["save_intermediate_out" ]:
97
115
self .check_grad (
98
116
['X' ], ['Out' ],
@@ -307,11 +325,29 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
307
325
'functor_list' : ["scale" , "elementwise_add" ],
308
326
'save_intermediate_out' : save_intermediate_out ,
309
327
})
328
+ create_test_class (
329
+ 'scale_add_fp16' + suffix ,
330
+ scale_add_func , {
331
+ 'scale' : scale ,
332
+ 'functor_list' : ["scale" , "elementwise_add" ],
333
+ 'save_intermediate_out' : save_intermediate_out ,
334
+ },
335
+ dtype = np .float16 ,
336
+ grad_chek = False )
310
337
create_test_class ('add_scale' + suffix , add_scale_func , {
311
338
'scale' : scale ,
312
339
'functor_list' : ["elementwise_add" , "scale" ],
313
340
'save_intermediate_out' : save_intermediate_out ,
314
341
})
342
+ create_test_class (
343
+ 'add_scale_fp16' + suffix ,
344
+ add_scale_func , {
345
+ 'scale' : scale ,
346
+ 'functor_list' : ["elementwise_add" , "scale" ],
347
+ 'save_intermediate_out' : save_intermediate_out ,
348
+ },
349
+ dtype = np .float16 ,
350
+ grad_chek = False )
315
351
create_test_class ('add_relu' + suffix , add_relu_func , {
316
352
'functor_list' : ["elementwise_add" , "relu" ],
317
353
'save_intermediate_out' : save_intermediate_out ,
@@ -320,11 +356,36 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
320
356
'functor_list' : ["relu" , "elementwise_add" ],
321
357
'save_intermediate_out' : save_intermediate_out ,
322
358
})
359
+ create_test_class (
360
+ 'add_relu_fp16' + suffix ,
361
+ add_relu_func , {
362
+ 'functor_list' : ["elementwise_add" , "relu" ],
363
+ 'save_intermediate_out' : save_intermediate_out ,
364
+ },
365
+ dtype = np .float16 ,
366
+ grad_chek = False )
367
+ create_test_class (
368
+ 'relu_add_fp16' + suffix ,
369
+ relu_add_func , {
370
+ 'functor_list' : ["relu" , "elementwise_add" ],
371
+ 'save_intermediate_out' : save_intermediate_out ,
372
+ },
373
+ dtype = np .float16 ,
374
+ grad_chek = False )
323
375
create_test_class ('mul_scale' + suffix , mul_scale_func , {
324
376
'scale' : scale ,
325
377
'functor_list' : ["elementwise_mul" , "scale" ],
326
378
'save_intermediate_out' : save_intermediate_out ,
327
379
})
380
+ create_test_class (
381
+ 'mul_scale' + suffix ,
382
+ mul_scale_func , {
383
+ 'scale' : scale ,
384
+ 'functor_list' : ["elementwise_mul" , "scale" ],
385
+ 'save_intermediate_out' : save_intermediate_out ,
386
+ },
387
+ dtype = np .float16 ,
388
+ grad_chek = False )
328
389
329
390
if __name__ == '__main__' :
330
391
unittest .main ()
0 commit comments