
Commit b504a23

Adding the Thresholded Relu Op (#4685)
* Adding thresholded_relu op
* Adding test for thresholded relu op
1 parent 2603cb7

File tree

3 files changed: +70 −1 lines changed
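
For context, the new op implements the thresholded ReLU described by the AddComment in the diff below:

    thresholded_relu(x) = x   for x > threshold
    thresholded_relu(x) = 0   otherwise

The threshold attribute defaults to 1.0, and the gradient passes the upstream gradient through only where x > threshold.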

paddle/operators/activation_op.cc

Lines changed: 21 additions & 0 deletions
@@ -321,6 +321,23 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ThresholdedReluOpMaker(framework::OpProto *proto,
+                         framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of ThresholdedRelu operator");
+    AddOutput("Y", "Output of ThresholdedRelu operator");
+    AddComment(
+        "ThresholdedRelu activation operator, "
+        "thresholded_relu = x for x > threshold, "
+        "thresholded_relu = 0 otherwise.");
+    AddAttr<AttrType>("threshold", "The threshold location of activation")
+        .SetDefault(static_cast<AttrType>(1.0));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 

@@ -392,6 +409,10 @@ REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
 REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
             hard_shrink_grad, ops::ActivationOpGrad);
 
+REGISTER_OP(thresholded_relu, ops::ActivationOp,
+            ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
+            ops::ActivationOpGrad);
+
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_CPU_KERNEL(                                               \
       act_type,                                                         \
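
The REGISTER_OP call follows the same pattern as hard_shrink directly above it: it reuses the generic ops::ActivationOp and ops::ActivationOpGrad classes and pairs the forward thresholded_relu with its gradient op thresholded_relu_grad, so only the maker here and the functors in the next file are new code.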

paddle/operators/activation_op.h

Lines changed: 28 additions & 1 deletion
@@ -590,6 +590,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x;
+  }
+};
+
+template <typename T>
+struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+    dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>();
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 

@@ -615,4 +641,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);    \
   __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
   __macro(elu, ELUFunctor, ELUGradFunctor);                       \
-  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
+  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
+  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
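
The two functors above carry all of the math: the forward masks x by the indicator (x > threshold), and the backward applies the same indicator to the upstream gradient (note both read X, not Y). A minimal NumPy sketch of the same computation, not part of the commit and with illustrative names:

    import numpy as np

    def thresholded_relu(x, threshold=1.0):
        # forward: y = x where x > threshold, 0 elsewhere
        return (x > threshold) * x

    def thresholded_relu_grad(x, dy, threshold=1.0):
        # backward: gradient passes through only where x > threshold
        return dy * (x > threshold)

The final hunk appends the functor pair to the backslash-continued macro list, so the generic kernel registration (REGISTER_ACTIVATION_CPU_KERNEL in the .cc diff above) picks up thresholded_relu alongside the other activations.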

python/paddle/v2/framework/tests/test_activation_op.py

Lines changed: 21 additions & 0 deletions
@@ -363,5 +363,26 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Y', max_relative_error=0.007)
 
 
+class TestThresholdedRelu(OpTest):
+    def setUp(self):
+        self.op_type = "thresholded_relu"
+        threshold = 0.25
+        self.relative_error = 0.005
+        X = np.random.uniform(-1, 1, [11, 17]).astype("float32")
+
+        # Same reason as TestAbs
+        X[np.abs(X - threshold) < self.relative_error] = threshold + 0.2
+
+        self.inputs = {'X': X}
+        self.attrs = {'threshold': threshold}
+        self.outputs = {'Y': (X > threshold) * X}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
+
+
 if __name__ == "__main__":
     unittest.main()
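
An aside on the "Same reason as TestAbs" comment above: thresholded_relu is discontinuous at the threshold (y jumps from 0 to roughly the threshold value), so a numeric gradient probed near that point can never match the analytic gradient of 0 or 1. The test therefore nudges any input within relative_error of the threshold safely past it. A small Python illustration, with values chosen here purely for demonstration:

    threshold, eps = 0.25, 1e-3

    def f(v):
        return (v > threshold) * v

    # centered finite difference straddling the jump at x == threshold
    numeric = (f(threshold + eps) - f(threshold - eps)) / (2 * eps)
    print(numeric)  # ~125.5, nowhere near the analytic gradient of 0 or 1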
