Skip to content

Commit 113c026

Browse files
authored
Swish activation operator (#6358)
1 parent 3a0a458 commit 113c026

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

paddle/operators/activation_op.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,22 @@ It is recommended to use the defaults for this activation.
506506
}
507507
};
508508

509+
// Proto maker for the swish op: declares the op's single input/output,
// its `beta` attribute (default 1.0), and the operator documentation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SwishOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Swish operator");
    AddOutput("Y", "Output of Swish operator");
    // Scaling constant applied to x inside the sigmoid (see the formula below).
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$y = \frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};
524+
509525
} // namespace operators
510526
} // namespace paddle
511527

@@ -592,6 +608,9 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp, ops::ThresholdedReluOpMaker,
592608
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
593609
hard_sigmoid_grad, ops::ActivationOpGrad);
594610

611+
// Register forward and backward swish ops; like the other activations above,
// both reuse the generic ActivationOp / ActivationOpGrad shape inference.
REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
            ops::ActivationOpGrad);
613+
595614
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
596615
REGISTER_OP_CPU_KERNEL( \
597616
act_type, \

paddle/operators/activation_op.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,35 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
700700
}
701701
};
702702

703+
// Forward functor for swish: y = x / (1 + exp(-beta * x)),
// i.e. y = x * sigmoid(beta * x).
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  // Expose `beta` so the framework can bind the op attribute to this field.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
  }
};
715+
716+
template <typename T>
717+
struct SwishGradFunctor : public BaseActivationFunctor<T> {
718+
float beta;
719+
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
720+
return {{"beta", &beta}};
721+
}
722+
723+
template <typename Device, typename X, typename Y, typename dY, typename dX>
724+
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
725+
auto temp1 = static_cast<T>(1) /
726+
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
727+
auto temp2 = temp1 * (static_cast<T>(1) - (beta * y));
728+
dx.device(d) = dy * ((beta * y) + temp2);
729+
}
730+
};
731+
703732
} // namespace operators
704733
} // namespace paddle
705734

@@ -730,4 +759,5 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
730759
__macro(elu, ELUFunctor, ELUGradFunctor); \
731760
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
732761
__macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
762+
__macro(swish, SwishFunctor, SwishGradFunctor); \
733763
__macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);

python/paddle/v2/fluid/tests/test_activation_op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import numpy as np
33
from op_test import OpTest
4+
from scipy.special import expit
45

56

67
class TestExp(OpTest):
@@ -455,5 +456,20 @@ def test_check_grad(self):
455456
self.check_grad(['X'], 'Y', max_relative_error=0.002)
456457

457458

459+
class TestSwish(OpTest):
    """Gradient/output check for the swish op: y = x * sigmoid(beta * x)."""

    def setUp(self):
        self.op_type = "swish"
        beta = 2.3
        x = np.random.uniform(0.1, 1, [11, 17]).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'beta': beta}
        # expit is the logistic sigmoid, so this matches the C++ kernel.
        self.outputs = {'Y': x * expit(beta * x)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.008)
472+
473+
458474
if __name__ == "__main__":
459475
unittest.main()

0 commit comments

Comments
 (0)