Skip to content

Commit 28d07e3

Browse files
committed
add python part of compare op
1 parent d4e3495 commit 28d07e3

File tree

5 files changed

+18
-12
lines changed

5 files changed

+18
-12
lines changed

paddle/fluid/operators/compare_op.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,12 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y");
100100
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
101101
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
102102
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
103-
REGISTER_COMPARE_OP(larger_than, "Out = X > Y");
104-
REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor);
105-
REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y");
106-
REGISTER_COMPARE_KERNEL(larger_equal, CPU,
107-
paddle::operators::LargerEqualFunctor);
103+
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
104+
REGISTER_COMPARE_KERNEL(greater_than, CPU,
105+
paddle::operators::GreaterThanFunctor);
106+
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
107+
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
108+
paddle::operators::GreaterEqualFunctor);
108109
REGISTER_COMPARE_OP(equal, "Out = X == Y");
109110
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
110111
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");

paddle/fluid/operators/compare_op.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ limitations under the License. */
1616

1717
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
1818
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
19-
REGISTER_COMPARE_KERNEL(larger_than, CUDA,
20-
paddle::operators::LargerThanFunctor);
21-
REGISTER_COMPARE_KERNEL(larger_equal, CUDA,
22-
paddle::operators::LargerEqualFunctor);
19+
REGISTER_COMPARE_KERNEL(greater_than, CUDA,
20+
paddle::operators::GreaterThanFunctor);
21+
REGISTER_COMPARE_KERNEL(greater_equal, CUDA,
22+
paddle::operators::GreaterEqualFunctor);
2323
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
2424
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);

paddle/fluid/operators/compare_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ struct LessEqualFunctor {
3535
};
3636

3737
template <typename T>
38-
struct LargerThanFunctor {
38+
struct GreaterThanFunctor {
3939
using ELEM_TYPE = T;
4040
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
4141
};
4242

4343
template <typename T>
44-
struct LargerEqualFunctor {
44+
struct GreaterEqualFunctor {
4545
using ELEM_TYPE = T;
4646
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
4747
};

python/paddle/v2/fluid/layers/math_op_patch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ def __impl__(self, other_var):
157157
("__eq__", "equal", False),
158158
("__ne__", "not_equal", False),
159159
("__lt__", "less_than", False),
160-
("__le__", "less_equal", False)):
160+
("__le__", "less_equal", False),
161+
("__gt__", "greater_than", False),
162+
("__ge__", "greater_equal", False)):
161163
setattr(Variable, method_name,
162164
_elemwise_method_creator_(method_name, op_type, reverse))
163165

python/paddle/v2/fluid/tests/unittests/test_compare_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ def test_output(self):
3838
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
3939
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
4040
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
41+
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
42+
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
4143
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
44+
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
4245

4346
if __name__ == '__main__':
4447
unittest.main()

0 commit comments

Comments
 (0)