
Commit 194f3dc

Support fp16 in GPU impl of fused_elemwise_activation_op. (#20636) (#20655)
* Support fp16 in fused_elemwise_activation_op.
* Fix unit testing in CPU-only mode.
Parent: ddcb81d

File tree: 3 files changed, 89 additions and 14 deletions


paddle/fluid/operators/fused/fused_elemwise_activation_op.cu

Lines changed: 6 additions & 2 deletions
@@ -20,11 +20,15 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
                                        float>,
     ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
-                                       double>);
+                                       double>,
+    ops::FusedElemwiseActivationKernel<paddle::platform::CUDADeviceContext,
+                                       paddle::platform::float16>);

 REGISTER_OP_CUDA_KERNEL(
     fused_elemwise_activation_grad,
     ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
                                            float>,
     ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
-                                           double>);
+                                           double>,
+    ops::FusedElemwiseActivationGradKernel<paddle::platform::CUDADeviceContext,
+                                           paddle::platform::float16>);
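Registering the float16 kernels above is only half of the change: the kernel's functor templates must also instantiate when T is paddle::platform::float16, which (unlike float and double) does not pick up integer or bool values through implicit conversions. That is why the functors.h hunk below rewrites expressions such as "return 1;" and "x * (x > 0)" with explicit static_cast<T> forms. Here is a minimal, self-contained sketch of the issue using a toy half-like type; it is an illustration only, not Paddle's float16.

// Toy illustration only: "toy_half" stands in for a half-precision type with
// explicit-only conversions. It is NOT paddle::platform::float16.
#include <iostream>

struct toy_half {
  float v;  // a real fp16 type stores 16 bits; a float keeps the toy simple
  toy_half() = default;
  explicit toy_half(float f) : v(f) {}
  explicit operator float() const { return v; }
};

inline toy_half operator*(toy_half a, toy_half b) {
  return toy_half(static_cast<float>(a) * static_cast<float>(b));
}
inline bool operator>(toy_half a, toy_half b) {
  return static_cast<float>(a) > static_cast<float>(b);
}

// Mirrors the ReluFunctor change: the result of the comparison has to be
// mapped to T explicitly, because bool/int never convert to toy_half
// implicitly.
template <typename T>
T Relu(T x) {
  // return x * (x > T(0));  // fails to compile for T = toy_half
  return x * (x > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0));
}

int main() {
  std::cout << Relu(3.0f) << " " << Relu(0.5) << "\n";            // 3 0.5
  std::cout << static_cast<float>(Relu(toy_half(1.5f))) << "\n";  // 1.5
  return 0;
}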

paddle/fluid/operators/math/functors.h

Lines changed: 19 additions & 9 deletions
@@ -14,6 +14,8 @@ limitations under the License. */

 #pragma once

+#include "paddle/fluid/operators/math.h"
+
 namespace paddle {
 namespace operators {
 namespace math {
@@ -40,8 +42,8 @@ struct AddFunctor {

 template <typename T>
 struct AddGradFunctor {
-  inline HOSTDEVICE T Dx(T x, T y) { return 1; }
-  inline HOSTDEVICE T Dy(T x, T y) { return 1; }
+  inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
+  inline HOSTDEVICE T Dy(T x, T y) { return static_cast<T>(1.); }
 };

 template <typename T>
@@ -68,14 +70,22 @@ struct ScaleGradFunctor {

 template <typename T>
 struct ReluFunctor {
-  inline HOSTDEVICE T operator()(T x) { return x * (x > 0); }
+  inline HOSTDEVICE T operator()(T x) {
+    return x * (x > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0));
+  }
 };

 template <typename T>
 struct ReluGradFunctor {
-  inline HOSTDEVICE T UseX(T x) { return x > 0 ? 1 : 0; }
-  inline HOSTDEVICE T UseOut(T out) { return out > 0 ? 1 : 0; }
-  inline HOSTDEVICE T UseXAndOut(T x, T out) { return out > 0 ? 1 : 0; }
+  inline HOSTDEVICE T UseX(T x) {
+    return x > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0);
+  }
+  inline HOSTDEVICE T UseOut(T out) {
+    return out > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0);
+  }
+  inline HOSTDEVICE T UseXAndOut(T x, T out) {
+    return out > static_cast<T>(0) ? static_cast<T>(1) : static_cast<T>(0);
+  }
 };

 template <typename T>
@@ -84,9 +94,9 @@ struct TanhFunctor {
   const T kMax = static_cast<T>(13);
   inline HOSTDEVICE T operator()(T x) {
     // y = 2 / (1 + e^-2x) - 1
-    T t0 = 2 * x;
+    T t0 = static_cast<T>(2) * x;
     T t1 = (t0 < kMin) ? kMin : ((t0 > kMax) ? kMax : t0);
-    return static_cast<T>(2) / (static_cast<T>(1) + std::exp(-t1)) -
+    return static_cast<T>(2) / (static_cast<T>(1) + real_exp(-t1)) -
            static_cast<T>(1);
   }
 };
@@ -107,7 +117,7 @@ struct SigmoidFunctor {
   inline HOSTDEVICE T operator()(T x) {
     // y = 1 / (1 + e^-x)
     T tmp = (x < kMin) ? kMin : ((x > kMax) ? kMax : x);
-    return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
+    return static_cast<T>(1) / (static_cast<T>(1) + real_exp(-tmp));
   }
 };

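The tanh and sigmoid functors also switch from std::exp to real_exp, which comes from the newly included paddle/fluid/operators/math.h; std::exp has no overload for paddle::platform::float16, so an overload set that dispatches on the element type is needed. That header is not part of this diff, so the following is only a sketch of the kind of overloads it presumably provides (widen fp16 to float, exponentiate, narrow back), not a copy of the actual file.

// Hypothetical sketch of a real_exp overload set; the real
// paddle/fluid/operators/math.h may differ in detail.
#include <math.h>

#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }
inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }

inline HOSTDEVICE platform::float16 real_exp(const platform::float16& x) {
  // No native half-precision exp is assumed here: compute in float and
  // convert the result back to float16.
  return static_cast<platform::float16>(::expf(static_cast<float>(x)));
}

}  // namespace operators
}  // namespace paddle

With overloads like these, TanhFunctor and SigmoidFunctor instantiate for float, double, and float16 without any change to their call sites.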

python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py

Lines changed: 64 additions & 3 deletions
@@ -33,17 +33,24 @@
 # TestFusedElementwiseActivationOp_channelwise_add


-def create_test_class(test_case, callback, attrs):
+def create_test_class(test_case,
+                      callback,
+                      attrs,
+                      dtype=np.float32,
+                      grad_chek=True):
     class TestFusedElementwiseActivationOp_base(OpTest):
         def setUp(self):
             self.op_type = "fused_elemwise_activation"
-            self.dtype = np.float32
+            self.dtype = dtype
             self.axis = -1

             self.init_input()
             self.init_output()
             self.init_attr()

+            self.out = self.out.astype(self.dtype)
+            self.intermediate_out = self.intermediate_out.astype(self.dtype)
+
             self.inputs = {
                 'X': OpTest.np_dtype_to_fluid_dtype(self.x),
                 'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
@@ -71,16 +78,25 @@ def init_attr(self):
                     self.attrs[key] = attrs[key]

         def test_check_output(self):
-            self.check_output()
+            if self.dtype == np.float16 and core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=1e-3)
+            else:
+                self.check_output()

         # FIXME(zcd): the intermediate_out_grad is not checked.
         def test_check_grad_normal(self):
+            if not grad_chek:
+                return
             if self.attrs["save_intermediate_out"]:
                 self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
             else:
                 self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)

         def test_check_grad_ingore_x(self):
+            if not grad_chek:
+                return
             if self.attrs["save_intermediate_out"]:
                 self.check_grad(
                     ['Y'], ['Out'],
@@ -93,6 +109,8 @@ def test_check_grad_ingore_x(self):
                     no_grad_set=set("X"))

         def test_check_grad_ingore_y(self):
+            if not grad_chek:
+                return
             if self.attrs["save_intermediate_out"]:
                 self.check_grad(
                     ['X'], ['Out'],
@@ -307,11 +325,29 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
             'functor_list': ["scale", "elementwise_add"],
             'save_intermediate_out': save_intermediate_out,
         })
+        create_test_class(
+            'scale_add_fp16' + suffix,
+            scale_add_func, {
+                'scale': scale,
+                'functor_list': ["scale", "elementwise_add"],
+                'save_intermediate_out': save_intermediate_out,
+            },
+            dtype=np.float16,
+            grad_chek=False)
         create_test_class('add_scale' + suffix, add_scale_func, {
             'scale': scale,
             'functor_list': ["elementwise_add", "scale"],
             'save_intermediate_out': save_intermediate_out,
         })
+        create_test_class(
+            'add_scale_fp16' + suffix,
+            add_scale_func, {
+                'scale': scale,
+                'functor_list': ["elementwise_add", "scale"],
+                'save_intermediate_out': save_intermediate_out,
+            },
+            dtype=np.float16,
+            grad_chek=False)
         create_test_class('add_relu' + suffix, add_relu_func, {
             'functor_list': ["elementwise_add", "relu"],
             'save_intermediate_out': save_intermediate_out,
@@ -320,11 +356,36 @@ def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
             'functor_list': ["relu", "elementwise_add"],
             'save_intermediate_out': save_intermediate_out,
         })
+        create_test_class(
+            'add_relu_fp16' + suffix,
+            add_relu_func, {
+                'functor_list': ["elementwise_add", "relu"],
+                'save_intermediate_out': save_intermediate_out,
+            },
+            dtype=np.float16,
+            grad_chek=False)
+        create_test_class(
+            'relu_add_fp16' + suffix,
+            relu_add_func, {
+                'functor_list': ["relu", "elementwise_add"],
+                'save_intermediate_out': save_intermediate_out,
+            },
+            dtype=np.float16,
+            grad_chek=False)
         create_test_class('mul_scale' + suffix, mul_scale_func, {
             'scale': scale,
             'functor_list': ["elementwise_mul", "scale"],
             'save_intermediate_out': save_intermediate_out,
         })
+        create_test_class(
+            'mul_scale' + suffix,
+            mul_scale_func, {
+                'scale': scale,
+                'functor_list': ["elementwise_mul", "scale"],
+                'save_intermediate_out': save_intermediate_out,
+            },
+            dtype=np.float16,
+            grad_chek=False)

 if __name__ == '__main__':
     unittest.main()
