Skip to content

Commit b2f530e

Browse files
authored
Merge pull request #8532 from jacquesqiao/fix-compare-op
Fix compare op
2 parents e32ecc4 + 28d07e3 commit b2f530e

File tree

5 files changed

+42
-15
lines changed

5 files changed

+42
-15
lines changed

paddle/fluid/operators/compare_op.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class CompareOp : public framework::OperatorWithKernel {
8383
} // namespace operators
8484
} // namespace paddle
8585

86-
#define REGISTER_LOGICAL_OP(op_type, _equation) \
86+
#define REGISTER_COMPARE_OP(op_type, _equation) \
8787
struct _##op_type##Comment { \
8888
static char type[]; \
8989
static char equation[]; \
@@ -96,11 +96,17 @@ class CompareOp : public framework::OperatorWithKernel {
9696
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
9797
::paddle::framework::EmptyGradOpMaker);
9898

99-
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
100-
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
101-
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
102-
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
103-
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
104-
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
105-
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
106-
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
99+
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
100+
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
101+
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
102+
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
103+
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
104+
REGISTER_COMPARE_KERNEL(greater_than, CPU,
105+
paddle::operators::GreaterThanFunctor);
106+
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
107+
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
108+
paddle::operators::GreaterEqualFunctor);
109+
REGISTER_COMPARE_OP(equal, "Out = X == Y");
110+
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
111+
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
112+
REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);

paddle/fluid/operators/compare_op.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/compare_op.h"
1616

17-
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
18-
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
19-
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
20-
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
17+
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
18+
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
19+
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
20+
paddle::operators::GreaterThanFunctor);
21+
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
22+
paddle::operators::GreaterEqualFunctor);
23+
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
24+
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);

paddle/fluid/operators/compare_op.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ struct LessEqualFunctor {
3434
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
3535
};
3636

37+
template <typename T>
38+
struct GreaterThanFunctor {
39+
using ELEM_TYPE = T;
40+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
41+
};
42+
43+
template <typename T>
44+
struct GreaterEqualFunctor {
45+
using ELEM_TYPE = T;
46+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
47+
};
48+
3749
template <typename T>
3850
struct EqualFunctor {
3951
using ELEM_TYPE = T;
@@ -76,7 +88,7 @@ class CompareOpKernel
7688
} // namespace operators
7789
} // namespace paddle
7890

79-
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \
91+
#define REGISTER_COMPARE_KERNEL(op_type, dev, functor) \
8092
REGISTER_OP_##dev##_KERNEL( \
8193
op_type, ::paddle::operators::CompareOpKernel< \
8294
::paddle::platform::dev##DeviceContext, functor<int>>, \

python/paddle/v2/fluid/layers/math_op_patch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ def __impl__(self, other_var):
157157
("__eq__", "equal", False),
158158
("__ne__", "not_equal", False),
159159
("__lt__", "less_than", False),
160-
("__le__", "less_equal", False)):
160+
("__le__", "less_equal", False),
161+
("__gt__", "greater_than", False),
162+
("__ge__", "greater_equal", False)):
161163
setattr(Variable, method_name,
162164
_elemwise_method_creator_(method_name, op_type, reverse))
163165

python/paddle/v2/fluid/tests/unittests/test_compare_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ def test_output(self):
3838
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
3939
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
4040
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
41+
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
42+
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
4143
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
44+
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
4245

4346
if __name__ == '__main__':
4447
unittest.main()

0 commit comments

Comments
 (0)