Skip to content

Commit c3a6134

Browse files
authored
Adding greater than and less than equal ops to compare op (#5609)
* Adding greater than and less than equal ops to compare op * Changing the name of the less_than_equal and greater_than_equal op * Also changing the name of the functors
1 parent 562599e commit c3a6134

File tree

4 files changed

+34
-0
lines changed

4 files changed

+34
-0
lines changed

paddle/operators/compare_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,13 @@ class CompareOp : public framework::OperatorWithKernel {
9494

9595
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
9696
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
97+
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
98+
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
99+
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
100+
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
101+
paddle::operators::GreaterThanFunctor);
102+
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
103+
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
104+
paddle::operators::GreaterEqualFunctor);
97105
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
98106
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);

paddle/operators/compare_op.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,9 @@
1515
#include "paddle/operators/compare_op.h"
1616

1717
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);
18+
REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor);
19+
REGISTER_LOGICAL_KERNEL(greater_than, GPU,
20+
paddle::operators::GreaterThanFunctor);
21+
REGISTER_LOGICAL_KERNEL(greater_equal, GPU,
22+
paddle::operators::GreaterEqualFunctor);
1823
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);

paddle/operators/compare_op.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@ struct LessThanFunctor {
2727
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
2828
};
2929

30+
template <typename T>
31+
struct LessEqualFunctor {
32+
using ELEM_TYPE = T;
33+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
34+
};
35+
36+
template <typename T>
37+
struct GreaterThanFunctor {
38+
using ELEM_TYPE = T;
39+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
40+
};
41+
42+
template <typename T>
43+
struct GreaterEqualFunctor {
44+
using ELEM_TYPE = T;
45+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
46+
};
47+
3048
template <typename T>
3149
struct EqualFunctor {
3250
using ELEM_TYPE = T;

python/paddle/v2/fluid/tests/test_compare_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def test_output(self):
2323

2424
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
2525
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
26+
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
27+
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
28+
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
2629
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
2730

2831
if __name__ == '__main__':

0 commit comments

Comments
 (0)