
Commit 436808c

elu support alpha < 0 (#37316) (#37437)
1 parent 58a5113 commit 436808c

7 files changed: +196 additions, −42 deletions


paddle/fluid/operators/activation_op.cc

Lines changed: 29 additions & 8 deletions
@@ -560,6 +560,22 @@ Applies the following element-wise computation on the input according to
   }
 };
 
+template <typename T>
+class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("elu_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput("X", this->Input("X"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
 class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -1233,13 +1249,11 @@ REGISTER_OP_CPU_KERNEL(
 /* ========================================================================== */
 
 /* ======================== elu register ============================ */
-REGISTER_OPERATOR(
-    elu, ops::ActivationOp, ops::ELUOpMaker, ops::ActivationOpInferVarType,
-    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
-                               paddle::framework::OpDesc>,
-    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
-                               paddle::imperative::OpBase>,
-    ops::ActFwdInplaceInferer);
+REGISTER_OPERATOR(elu, ops::ActivationOp, ops::ELUOpMaker,
+                  ops::ActivationOpInferVarType,
+                  ops::ELUGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ELUGradOpMaker<paddle::imperative::OpBase>,
+                  ops::ActFwdInplaceInferer);
 REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
                   ops::ActivationGradOpInplaceInferer,
                   ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
@@ -1249,7 +1263,14 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
+REGISTER_OP_CPU_KERNEL(elu,
+                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
+                                             ops::ELUFunctor<float>>,
+                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
+                                             ops::ELUFunctor<double>>);
+REGISTER_OP_CPU_KERNEL(
+    elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                                             ops::ELUGradGradFunctor<float>>,
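Note on the registration change above: the dedicated ELUGradOpMaker feeds both Out and X into elu_grad, so the new kernels can pick a gradient rule based on the sign of alpha at runtime. An illustrative NumPy sketch of that rule (not Paddle code; the helper name elu_grad_ref is made up):

import numpy as np

# Sketch of the rule the rewritten elu_grad kernels implement.
def elu_grad_ref(dout, x, out, alpha):
    if alpha >= 0:
        # for alpha >= 0, out > 0 exactly when x > 0, so out alone suffices
        return np.where(out > 0, dout, dout * (out + alpha))
    # alpha < 0: the branch has to use the sign of x instead
    return np.where(x > 0, dout, dout * (out + alpha))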

paddle/fluid/operators/activation_op.cu

Lines changed: 76 additions & 15 deletions
@@ -1161,11 +1161,12 @@ struct CudaELUFunctor : public BaseActivationFunctor<T> {
     return {{"alpha", &alpha}};
   }
 
-  // elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
+  // elu(x) = x, if x > 0
+  // elu(x) = alpha * (e^x - 1), if x <= 0
   __device__ __forceinline__ T operator()(const T& arg_x) const {
     CT x = static_cast<CT>(arg_x);
     CT temp = static_cast<CT>(alpha) * (exp(x) - one);
-    CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
+    CT res = x > zero ? x : temp;
     return static_cast<T>(res);
   }
 };
@@ -1174,34 +1175,84 @@ template <typename T>
 struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename details::MPTypeTrait<T>::Type;
   MPType zero = static_cast<MPType>(0.0f);
-  MPType one = static_cast<MPType>(1.0f);
   float alpha;
 
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"alpha", &alpha}};
   }
 
-  // dx = dout, if alpha > 0 and x > 0
-  // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
-  // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
-  // dx = 0, if alpha <= 0 and x <=0
+  // case 1: alpha >= 0
+  // dx = dout, if out > 0
+  // dx = dout * (out + alpha), if out <= 0
   __device__ __forceinline__ T operator()(const T& arg_dout,
+                                          const T& arg_out) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType out = static_cast<MPType>(arg_out);
+    MPType a = static_cast<MPType>(alpha);
+    MPType out_pos = static_cast<MPType>(out > zero);
+    MPType out_neg = static_cast<MPType>(out <= zero);
+    return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+};
+
+template <typename T>
+struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
+  using MPType = typename details::MPTypeTrait<T>::Type;
+  MPType zero = static_cast<MPType>(0.0f);
+  float alpha;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  // case 2: alpha < 0
+  // dx = dout, if x > 0
+  // dx = dout * (out + alpha), if x <=0
+  __device__ __forceinline__ T operator()(const T& arg_dout, const T& arg_out,
                                           const T& arg_x) const {
     MPType dout = static_cast<MPType>(arg_dout);
+    MPType out = static_cast<MPType>(arg_out);
     MPType x = static_cast<MPType>(arg_x);
     MPType a = static_cast<MPType>(alpha);
-    MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
-    MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
-    MPType temp_x_pos = static_cast<MPType>(x > zero);
-    MPType temp_x_neg = static_cast<MPType>(x <= zero);
-    return static_cast<T>(
-        dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
-                temp_a_neg * temp_x_pos * (one + a * exp(x))));
+    MPType x_pos = static_cast<MPType>(x > zero);
+    MPType x_neg = static_cast<MPType>(x <= zero);
+    return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
+template <typename DeviceContext, typename T>
+class ELUGradCudaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out = ctx.Input<framework::Tensor>("Out");
+    auto* x = ctx.Input<framework::Tensor>("X");
+    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    d_x->mutable_data<T>(ctx.GetPlace());
+    const float alpha = ctx.Attr<float>("alpha");
+
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+    std::vector<const framework::Tensor*> ins = {d_out, out};
+    std::vector<framework::Tensor*> outs = {d_x};
+    if (alpha > 0) {
+      CudaELUGradFunctor<T> functor;
+      functor.alpha = alpha;
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
+    } else {
+      CudaELUGradNegativeAlphaFunctor<T> functor;
+      functor.alpha = alpha;
+      ins.push_back(x);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          dev_ctx, ins, &outs, functor);
+    }
+  }
+};
+
 template <typename DeviceContext, typename Functor>
 class ActivationCudaKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1330,7 +1381,17 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ======================== elu register ============================ */
-REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);
+REGISTER_OP_CUDA_KERNEL(
+    elu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
+                                   ops::CudaELUFunctor<float>>,
+    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
+                              ops::CudaELUFunctor<double>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaELUFunctor<plat::float16>>);
+REGISTER_OP_CUDA_KERNEL(
+    elu_grad, ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
+    ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
+    ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);
 
 REGISTER_OP_CUDA_KERNEL(
     elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
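Both CUDA gradient functors above rely on the same algebraic shortcut: on the negative branch out = alpha * (e^x - 1), so d(out)/dx = alpha * e^x = out + alpha, which lets the common case skip reloading X. A quick NumPy check of that identity (illustrative only):

import numpy as np

# For out = alpha * (exp(x) - 1): d(out)/dx = alpha * exp(x) = out + alpha.
x = np.linspace(-4.0, 0.0, 9)
for alpha in (1.0, -0.2):
    out = alpha * (np.exp(x) - 1)
    assert np.allclose(alpha * np.exp(x), out + alpha)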

paddle/fluid/operators/activation_op.h

Lines changed: 59 additions & 14 deletions
@@ -1311,25 +1311,70 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp_a_pos = static_cast<T>(alpha > 0);
-    auto temp_a_neg = static_cast<T>(alpha <= 0);
-    auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
-    auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
-
-    // dx = dout, if alpha > 0 and x > 0
-    // dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
-    // dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
-    // dx = 0, if alpha <= 0 and x <=0
-    dx.device(d) =
-        dout * temp_a_pos * temp_x_pos +
-        dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
-        dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
-            temp_a_neg * temp_x_pos;
+    // case 1: alpha >= 0
+    // dx = dout, if out > 0
+    // dx = dout * (out + alpha), if out <= 0
+    dx.device(d) = (out > static_cast<T>(0))
+                       .select(dout, dout * (out + static_cast<T>(alpha)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
+};
+
+template <typename T>
+struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
+  float alpha;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    // case 2: alpha < 0
+    // dx = dout, if x > 0
+    // dx = dout * (out + alpha), if x <=0
+    dx.device(d) = (x > static_cast<T>(0))
+                       .select(dout, dout * static_cast<T>(alpha) * x.exp());
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
+template <typename DeviceContext, typename T>
+class ELUGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<framework::Tensor>("X");
+    auto* Out = context.Input<framework::Tensor>("Out");
+    auto* dOut =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    const float alpha = context.Attr<float>("alpha");
+    dX->mutable_data<T>(context.GetPlace());
+
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
+    auto dout = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
+    auto dx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
+    auto* place =
+        context.template device_context<DeviceContext>().eigen_device();
+
+    if (alpha > 0) {
+      ELUGradFunctor<T> functor;
+      functor.alpha = alpha;
+      functor(*place, x, out, dout, dx);
+    } else {
+      ELUGradNegativeAlphaFunctor<T> functor;
+      functor.alpha = alpha;
+      functor(*place, x, out, dout, dx);
+    }
+  }
+};
+
 // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
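The CPU path above mirrors the CUDA split: ELUGradFunctor branches on out (safe when alpha >= 0, where out > 0 exactly when x > 0), while ELUGradNegativeAlphaFunctor branches on x. A small NumPy illustration of why the out-based test breaks once alpha < 0 (illustrative only):

import numpy as np

# With alpha < 0, out = alpha * (exp(x) - 1) is positive where x < 0,
# so a branch on "out > 0" would wrongly take the pass-through path.
alpha = -0.2
x = np.array([-2.0, -0.5, 0.5, 2.0])
out = np.where(x > 0, x, alpha * (np.exp(x) - 1))
print(out > 0)  # [ True  True  True  True]
print(x > 0)    # [False False  True  True] -- the branch the gradient needs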

paddle/fluid/operators/inplace_abn_op.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ class InplaceABNActivation {
       auto temp2 = (y * temp / static_cast<T>(alpha) + static_cast<T>(1)).log();
       x.device(d) = (y * temp1 + temp2).template cast<T>();
 
-      ELUGradFunctor<T> functor;
+      ELUGradNegativeAlphaFunctor<T> functor;
       compute(ctx, &functor, d, x, y, dy, dx);
     } else {
       PADDLE_THROW(

python/paddle/fluid/tests/unittests/test_activation_op.py

Lines changed: 16 additions & 2 deletions
@@ -1742,7 +1742,7 @@ def test_errors(self):
 
 
 def elu(x, alpha):
-    out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
+    out_ref = np.where(x > 0, x, alpha * (np.exp(x) - 1))
     return out_ref.astype(x.dtype)
 
 
@@ -1753,7 +1753,7 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
-        alpha = 1.
+        alpha = self.get_alpha()
         out = elu(x, alpha)
         # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
         # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
@@ -1766,6 +1766,14 @@ def test_check_grad(self):
             return
         self.check_grad(['X'], 'Out')
 
+    def get_alpha(self):
+        return 1.
+
+
+class TestELUAlpha(TestELU):
+    def get_alpha(self):
+        return -0.2
+
 
 class TestELUAPI(unittest.TestCase):
     # test paddle.nn.ELU, paddle.nn.functional.elu
@@ -1832,6 +1840,12 @@ class TestELUInplaceAPI(TestELUAPI):
     def executed_api(self):
         self.elu = F.elu_
 
+    def test_alpha_error(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x_np)
+        self.assertRaises(Exception, F.elu_, x, -0.2)
+        paddle.enable_static()
+
 
 class TestReciprocal(TestActivation):
     def setUp(self):
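The reference elu above moves from the max/min form to an explicit piecewise form because the two only coincide when alpha >= 0; a quick NumPy comparison (illustrative only):

import numpy as np

np.random.seed(0)
x = np.random.uniform(-3, 3, 1000)
for alpha in (1.0, -0.2):
    old = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
    new = np.where(x > 0, x, alpha * (np.exp(x) - 1))
    print(alpha, np.allclose(old, new))  # True for 1.0, False for -0.2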

python/paddle/nn/functional/activation.py

Lines changed: 8 additions & 1 deletion
@@ -37,7 +37,13 @@ def elu(x, alpha=1.0, name=None):
 
     .. math::
 
-        elu(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
+        elu(x)=
+            \left\{
+                \begin{array}{lcl}
+                x,& &\text{if } \ x > 0 \\
+                alpha * (e^{x} - 1),& &\text{if } \ x <= 0
+                \end{array}
+            \right.
 
     Parameters:
         x (Tensor): The input Tensor with data type float32, float64.
@@ -80,6 +86,7 @@ def elu_(x, alpha=1.0, name=None):
     Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_nn_cn_elu`.
     """
+    assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
    return _C_ops.elu_(x, 'alpha', alpha)
 
 
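In practice this means a negative alpha is accepted by the out-of-place API while the in-place variant now rejects it up front; a minimal usage sketch, assuming a Paddle build that includes this patch:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-1.0, 2.0])
y = F.elu(x, alpha=-0.2)   # negative alpha is supported out of place
# F.elu_(x, alpha=-0.2)    # would trip the new assert: alpha must be >= 0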
python/paddle/nn/layer/activation.py

Lines changed: 7 additions & 1 deletion
@@ -31,7 +31,13 @@ class ELU(Layer):
 
     .. math::
 
-        ELU(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
+        ELU(x)=
+            \left\{
+                \begin{array}{lcl}
+                x,& &\text{if } \ x > 0 \\
+                alpha * (e^{x} - 1),& &\text{if } \ x <= 0
+                \end{array}
+            \right.
 
     Parameters:
         alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
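The layer form follows the same piecewise definition; a minimal sketch, again assuming this patch is applied:

import paddle

m = paddle.nn.ELU(alpha=-0.2)           # alpha < 0 is valid after this change
out = m(paddle.to_tensor([-1.0, 2.0]))  # [-0.2 * (e^-1 - 1), 2.0]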
