
Commit 6c71c1f

Clementine authored and Yibing Liu committed
Add activation gelu (#14569)
1 parent 6648f5e commit 6c71c1f

4 files changed: 65 additions, 1 deletion


paddle/fluid/operators/activation_op.cc

Lines changed: 10 additions & 0 deletions
@@ -149,6 +149,13 @@ Relu Activation Operator.
 
 )DOC";
 
+UNUSED constexpr char GeluDoc[] = R"DOC(
+Gelu Activation Operator.
+
+$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$
+
+)DOC";
+
 UNUSED constexpr char TanhDoc[] = R"DOC(
 Tanh Activation Operator.
 
@@ -472,6 +479,7 @@ REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
 REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
 REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
 REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
+REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
 REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
 REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
 REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
@@ -489,6 +497,7 @@ REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
 
 REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
 REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
+REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
 REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
 REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
 REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
@@ -525,6 +534,7 @@ namespace ops = paddle::operators;
 __macro(Round, round); \
 __macro(Log, log); \
 __macro(Square, square); \
+__macro(Gelu, gelu); \
 __macro(BRelu, brelu); \
 __macro(Pow, pow); \
 __macro(STanh, stanh); \
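For reference, the escaped formula in GeluDoc above is the exact (erf-based) GELU rather than the tanh approximation; rendered as plain LaTeX it reads:

$out = \frac{x}{2}\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$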

paddle/fluid/operators/activation_op.h

Lines changed: 31 additions & 0 deletions
@@ -16,6 +16,11 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include <cmath>
+#ifndef _USE_MATH_DEFINES
+#define _USE_MATH_DEFINES
+#endif
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -212,6 +217,31 @@ struct ReluGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+template <typename T>
+struct GeluFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    auto temp =
+        ((x * static_cast<T>(M_SQRT1_2)).erf()).template cast<T>().eval();
+    out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
+  }
+};
+
+template <typename T>
+struct GeluGradFunctor : BaseActivationFunctor<T> {
+  bool Inplace() const { return IsInplace("gelu"); }
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    auto temp = (static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
+                 ((-static_cast<T>(0.5) * x.square()).exp()))
+                    .template cast<T>()
+                    .eval();
+    dx.device(d) = dout * (out / x + temp);
+  }
+};
+
 // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
 template <typename T>
 struct TanhFunctor : public BaseActivationFunctor<T> {
@@ -877,6 +907,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
 __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
 __macro(exp, ExpFunctor, ExpGradFunctor); \
 __macro(relu, ReluFunctor, ReluGradFunctor); \
+__macro(gelu, GeluFunctor, GeluGradFunctor); \
 __macro(tanh, TanhFunctor, TanhGradFunctor); \
 __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
 __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
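A note on GeluGradFunctor above: differentiating $\mathrm{gelu}(x) = \tfrac{x}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr)$ gives

$\frac{d}{dx}\,\mathrm{gelu}(x) = \underbrace{\tfrac{1}{2}\Bigl(1 + \operatorname{erf}\bigl(\tfrac{x}{\sqrt{2}}\bigr)\Bigr)}_{out\,/\,x} + \underbrace{\tfrac{x}{\sqrt{2\pi}}\, e^{-x^{2}/2}}_{temp}$

The first term is recovered as out / x, which avoids a second erf evaluation, and the constant in temp comes from $0.5 \cdot \tfrac{2}{\sqrt{\pi}} \cdot \tfrac{1}{\sqrt{2}} = \tfrac{1}{\sqrt{2\pi}}$ (i.e. 0.5 * M_2_SQRTPI * M_SQRT1_2 in the code).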

paddle/fluid/platform/float16.h

Lines changed: 5 additions & 0 deletions
@@ -1039,6 +1039,11 @@ HOSTDEVICE inline float16 exp(const float16& a) {
   return float16(::expf(static_cast<float>(a)));
 }
 
+template <>
+HOSTDEVICE inline float16 erf(const float16& a) {
+  return float16(::erff(static_cast<float>(a)));
+}
+
 template <>
 HOSTDEVICE inline float16 log(const float16& a) {
   return float16(::logf(static_cast<float>(a)));
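The erf specialization above follows the same upcast-to-float pattern as the neighboring exp and log overloads. A hypothetical NumPy/SciPy cross-check of that pattern (not part of the commit):

import numpy as np
from scipy.special import erf

x = np.random.uniform(-3, 3, size=1000).astype(np.float16)
# Mirrors the float16 specialization: upcast to float32, take erf, round back to half.
half_erf = erf(x.astype(np.float32)).astype(np.float16)
full_erf = erf(x.astype(np.float64))                          # double-precision reference
print(np.abs(half_erf.astype(np.float64) - full_erf).max())   # stays within float16 rounding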

python/paddle/fluid/tests/unittests/test_activation_op.py

Lines changed: 19 additions & 1 deletion
@@ -18,7 +18,7 @@
 import numpy as np
 import paddle.fluid.core as core
 from op_test import OpTest
-from scipy.special import expit
+from scipy.special import expit, erf
 
 
 class TestActivation(OpTest):
@@ -295,6 +295,23 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out', max_relative_error=0.007)
 
 
+class TestGelu(TestActivation):
+    def setUp(self):
+        self.op_type = "gelu"
+        self.init_dtype()
+
+        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+        out = 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
+
+
 class TestBRelu(TestActivation):
     def setUp(self):
         self.op_type = "brelu"
@@ -628,6 +645,7 @@ def test_check_grad(self):
 create_test_act_fp16_class(TestSin)
 create_test_act_fp16_class(TestRound, grad_check=False)
 create_test_act_fp16_class(TestRelu)
+create_test_act_fp16_class(TestGelu)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
 create_test_act_fp16_class(TestSoftRelu)
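Beyond the OpTest-based gradient check above, a standalone sketch (not part of the test suite) that verifies the analytic GELU gradient from GeluGradFunctor against a central-difference estimate, using the same NumPy/SciPy primitives as the test:

import numpy as np
from scipy.special import erf

def gelu(x):
    # forward formula used in TestGelu above
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_grad(x):
    # analytic derivative: 0.5 * (1 + erf(x / sqrt(2))) + x / sqrt(2 * pi) * exp(-x**2 / 2)
    return 0.5 * (1.0 + erf(x / np.sqrt(2.0))) + x / np.sqrt(2.0 * np.pi) * np.exp(-0.5 * x ** 2)

x = np.random.uniform(-1, 1, [11, 17]).astype(np.float64)
eps = 1e-6
numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0 * eps)   # central difference
assert np.allclose(gelu_grad(x), numeric, rtol=1e-5, atol=1e-8)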
