Skip to content

Commit d4e3495

Browse files
committed
add larger_than and larger_equal op and kernel
1 parent bad0159 commit d4e3495

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

paddle/fluid/operators/compare_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ REGISTER_COMPARE_OP(less_than, "Out = X < Y");
100100
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
101101
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
102102
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
103+
REGISTER_COMPARE_OP(larger_than, "Out = X > Y");
104+
REGISTER_COMPARE_KERNEL(larger_than, CPU, paddle::operators::LargerThanFunctor);
105+
REGISTER_COMPARE_OP(larger_equal, "Out = X >= Y");
106+
REGISTER_COMPARE_KERNEL(larger_equal, CPU,
107+
paddle::operators::LargerEqualFunctor);
103108
REGISTER_COMPARE_OP(equal, "Out = X == Y");
104109
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
105110
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");

paddle/fluid/operators/compare_op.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,9 @@ limitations under the License. */
1616

1717
REGISTER_COMPARE_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
1818
REGISTER_COMPARE_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
19+
REGISTER_COMPARE_KERNEL(larger_than, CUDA,
20+
paddle::operators::LargerThanFunctor);
21+
REGISTER_COMPARE_KERNEL(larger_equal, CUDA,
22+
paddle::operators::LargerEqualFunctor);
1923
REGISTER_COMPARE_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
2024
REGISTER_COMPARE_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);

paddle/fluid/operators/compare_op.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ struct LessEqualFunctor {
3434
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
3535
};
3636

37+
template <typename T>
38+
struct LargerThanFunctor {
39+
using ELEM_TYPE = T;
40+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
41+
};
42+
43+
template <typename T>
44+
struct LargerEqualFunctor {
45+
using ELEM_TYPE = T;
46+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
47+
};
48+
3749
template <typename T>
3850
struct EqualFunctor {
3951
using ELEM_TYPE = T;

0 commit comments

Comments
 (0)