
Commit 513b1e0

"add floor, ceil, round op" (#5898)
* "add floor, ceil, round op" * "reuse zero gradient" * "fix divide zero" * "fix numpy floor error"
1 parent 45062fe commit 513b1e0
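
Before the per-file diffs, a quick sketch of what the three new ops compute elementwise. All three are piecewise constant, which is why the patch can reuse a single zero-gradient functor for them (the "reuse zero gradient" bullet above). numpy stands in here for the Eigen expressions the kernels actually use; tie-breaking at .5 is the one behavioral gap between the two:

import numpy as np

x = np.array([-1.7, -0.5, 0.0, 0.4, 2.5], dtype=np.float32)

print(np.ceil(x))   # [-1. -0.  0.  1.  3.]  smallest integer >= x
print(np.floor(x))  # [-2. -1.  0.  0.  2.]  largest integer <= x
print(np.round(x))  # [-2. -0.  0.  0.  2.]  nearest integer; numpy breaks
                    # .5 ties toward even, while Eigen (like std::round)
                    # breaks them away from zero (-1. and 3. here)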

File tree

3 files changed: +135 −0 lines

* paddle/operators/activation_op.cc
* paddle/operators/activation_op.h
* python/paddle/v2/fluid/tests/test_activation_op.py

paddle/operators/activation_op.cc

Lines changed: 54 additions & 0 deletions
@@ -223,6 +223,51 @@ Abs Activation Operator.
 }
 };
 
+class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  CeilOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Ceil operator");
+    AddOutput("Y", "Output of Ceil operator");
+    AddComment(R"DOC(
+Ceil Activation Operator.
+
+$y = ceil(x)$
+
+)DOC");
+  }
+};
+
+class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FloorOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Floor operator");
+    AddOutput("Y", "Output of Floor operator");
+    AddComment(R"DOC(
+Floor Activation Operator.
+
+$y = floor(x)$
+
+)DOC");
+  }
+};
+
+class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  RoundOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Round operator");
+    AddOutput("Y", "Output of Round operator");
+    AddComment(R"DOC(
+Round Activation Operator.
+
+$y = [x]$
+
+)DOC");
+  }
+};
+
 class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ReciprocalOpMaker(framework::OpProto *proto,
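
Spelled out, the formulas in the three DOC strings above are, with $[x]$ the nearest-integer notation used by RoundOpMaker:

$$\lceil x \rceil = \min\{n \in \mathbb{Z} : n \ge x\}, \qquad \lfloor x \rfloor = \max\{n \in \mathbb{Z} : n \le x\}, \qquad [x] \in \mathbb{Z},\ \lvert [x] - x \rvert \le \tfrac{1}{2}$$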
@@ -493,6 +538,15 @@ REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
 REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
             ops::ActivationOpGrad);
 
+REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
+            ops::ActivationOpGrad);
+
+REGISTER_OP(floor, ops::ActivationOp, ops::FloorOpMaker, floor_grad,
+            ops::ActivationOpGrad);
+
+REGISTER_OP(round, ops::ActivationOp, ops::RoundOpMaker, round_grad,
+            ops::ActivationOpGrad);
+
 REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
             reciprocal_grad, ops::ActivationOpGrad);

paddle/operators/activation_op.h

Lines changed: 38 additions & 0 deletions
@@ -283,6 +283,41 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
 }
 };
 
+// ceil(x): the smallest integer not less than x.
+template <typename T>
+struct CeilFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = x.ceil();
+  }
+};
+
+// Shared by ceil, floor and round: all three are piecewise constant,
+// so their gradient is zero (almost) everywhere.
+template <typename T>
+struct ZeroGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+    // Multiply, don't divide: static_cast<T>(0) / x is NaN at x == 0.
+    dx.device(d) = static_cast<T>(0) * x;
+  }
+};
+
+// floor(x): the largest integer not greater than x.
+template <typename T>
+struct FloorFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = x.floor();
+  }
+};
+
+// round(x) = [x], the nearest integer to x.
+template <typename T>
+struct RoundFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = x.round();
+  }
+};
+
 // abs(x) = |x|
 template <typename T>
 struct AbsFunctor : public BaseActivationFunctor<T> {
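
A note on ZeroGradFunctor: the gradient expression only needs to broadcast a zero of type T across a tensor shaped like x. Writing it as static_cast<T>(0) / x looks equivalent but produces NaN wherever x == 0 (presumably what the "fix divide zero" bullet in the commit message refers to); the multiply form above avoids that. A numpy stand-in for the two Eigen expressions:

import numpy as np

x = np.array([-1.5, 0.0, 2.0], dtype=np.float32)

# 0 / x: NaN at x == 0, because 0/0 is undefined.
with np.errstate(invalid="ignore"):
    print(np.float32(0) / x)  # [-0. nan  0.]

# 0 * x: zeros everywhere (for finite x), same shape and dtype as x.
print(np.float32(0) * x)      # [-0.  0.  0.]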
@@ -677,6 +712,9 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
   __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
   __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                   \
   __macro(abs, AbsFunctor, AbsGradFunctor);                      \
+  __macro(ceil, CeilFunctor, ZeroGradFunctor);                   \
+  __macro(floor, FloorFunctor, ZeroGradFunctor);                 \
+  __macro(round, RoundFunctor, ZeroGradFunctor);                 \
   __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
   __macro(log, LogFunctor, LogGradFunctor);                      \
   __macro(square, SquareFunctor, SquareGradFunctor);             \

python/paddle/v2/fluid/tests/test_activation_op.py

Lines changed: 43 additions & 0 deletions
@@ -152,6 +152,49 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Y', max_relative_error=0.007)
 
 
+class TestCeil(OpTest):
+    def setUp(self):
+        self.op_type = "ceil"
+        x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
+        self.inputs = {'X': x}
+        self.outputs = {'Y': np.ceil(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestFloor(OpTest):
+    def setUp(self):
+        self.op_type = "floor"
+        x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
+        self.inputs = {'X': x}
+        # With FloorFunctor using x.floor(), np.floor is the exact reference.
+        self.outputs = {'Y': np.floor(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestRound(OpTest):
+    def setUp(self):
+        self.op_type = "round"
+        x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
+        self.inputs = {'X': x}
+        self.outputs = {'Y': np.round(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
 class TestRelu(OpTest):
     def setUp(self):
         self.op_type = "relu"
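
Why test_check_grad passes even though the analytic gradient is identically zero: check_grad compares against a finite-difference estimate, and for a piecewise-constant function the central difference quotient is exactly zero unless the stencil straddles a jump, which a random draw almost never does. A minimal sketch of that argument (eps is illustrative; OpTest's internal step size may differ):

import numpy as np

def numeric_grad(f, x, eps=1e-3):
    # Elementwise central difference.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
for f in (np.ceil, np.floor, np.round):
    # Nonzero only if some entry of x lies within eps of a jump of f;
    # for a 4x4 uniform draw over (-1, 1) that is vanishingly unlikely.
    print(f.__name__, np.abs(numeric_grad(f, x)).max())

One remaining subtlety: np.round breaks .5 ties toward the nearest even integer, while Eigen's round() (like std::round) breaks them away from zero. TestRound never exposes the difference because a continuous uniform draw hits an exact .5 with probability zero.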
