Skip to content

Commit 54a4696

Browse files
authored
Merge pull request #7660 from reyoung/feature/compare_op_use_elemwise
Make compare_op reuse elemwise_op_funcs
2 parents 430fdc5 + 2024489 commit 54a4696

File tree

5 files changed

+15
-38
lines changed

5 files changed

+15
-38
lines changed

paddle/operators/compare_op.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is
3939
calculated by %s
4040
)DOC",
4141
comment.type, comment.equation));
42+
AddAttr<int>("axis",
43+
"(int, default -1). The start dimension index "
44+
"for broadcasting Y onto X.")
45+
.SetDefault(-1)
46+
.EqualGreaterThan(-1);
4247
}
4348
};
4449

@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
95100
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
96101
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
97102
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
98-
REGISTER_LOGICAL_OP(greater_than, "Out = X > Y");
99-
REGISTER_LOGICAL_KERNEL(greater_than, CPU,
100-
paddle::operators::GreaterThanFunctor);
101-
REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y");
102-
REGISTER_LOGICAL_KERNEL(greater_equal, CPU,
103-
paddle::operators::GreaterEqualFunctor);
104103
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
105104
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);

paddle/operators/compare_op.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,4 @@ limitations under the License. */
1616

1717
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
1818
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
19-
REGISTER_LOGICAL_KERNEL(greater_than, CUDA,
20-
paddle::operators::GreaterThanFunctor);
21-
REGISTER_LOGICAL_KERNEL(greater_equal, CUDA,
22-
paddle::operators::GreaterEqualFunctor);
2319
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);

paddle/operators/compare_op.h

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include <math.h>
1717
#include <type_traits>
1818
#include "paddle/framework/op_registry.h"
19+
#include "paddle/operators/elementwise_op_function.h"
1920
#include "paddle/platform/transform.h"
2021

2122
namespace paddle {
@@ -33,18 +34,6 @@ struct LessEqualFunctor {
3334
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
3435
};
3536

36-
template <typename T>
37-
struct GreaterThanFunctor {
38-
using ELEM_TYPE = T;
39-
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
40-
};
41-
42-
template <typename T>
43-
struct GreaterEqualFunctor {
44-
using ELEM_TYPE = T;
45-
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
46-
};
47-
4837
template <typename T>
4938
struct EqualFunctor {
5039
using ELEM_TYPE = T;
@@ -65,14 +54,7 @@ class CompareOpKernel
6554
public:
6655
void Compute(const framework::ExecutionContext& context) const override {
6756
using T = typename Functor::ELEM_TYPE;
68-
auto* x = context.Input<framework::Tensor>("X");
69-
auto* y = context.Input<framework::Tensor>("Y");
70-
auto* out = context.Output<framework::Tensor>("Out");
71-
Functor binary_func;
72-
platform::Transform<DeviceContext> trans;
73-
trans(context.template device_context<DeviceContext>(), x->data<T>(),
74-
x->data<T>() + x->numel(), y->data<T>(),
75-
out->mutable_data<bool>(context.GetPlace()), binary_func);
57+
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context);
7658
}
7759
};
7860

paddle/operators/elementwise_op_function.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
176176
};
177177
#endif
178178

179-
template <typename Functor, typename T, typename DeviceContext>
179+
template <typename Functor, typename T, typename DeviceContext,
180+
typename OutType = T>
180181
class TransformFunctor {
181182
public:
182183
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
183184
framework::Tensor* z, const DeviceContext& ctx, Functor func)
184185
: x_(x->data<T>()),
185186
y_(y->data<T>()),
186-
z_(z->mutable_data<T>(ctx.GetPlace())),
187+
z_(z->mutable_data<OutType>(ctx.GetPlace())),
187188
nx_(x->numel()),
188189
ctx_(ctx),
189190
func_(func) {}
@@ -208,7 +209,7 @@ class TransformFunctor {
208209
private:
209210
const T* x_;
210211
const T* y_;
211-
T* z_;
212+
OutType* z_;
212213
int64_t nx_;
213214
const DeviceContext& ctx_;
214215
Functor func_;
@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
364365
}
365366
}
366367

367-
template <typename Functor, typename DeviceContext, typename T>
368+
template <typename Functor, typename DeviceContext, typename T,
369+
typename OutType = T>
368370
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
369371
using Tensor = framework::Tensor;
370372

371373
auto* x = ctx.Input<Tensor>("X");
372374
auto* y = ctx.Input<Tensor>("Y");
373375
auto* z = ctx.Output<Tensor>("Out");
374-
z->mutable_data<T>(ctx.GetPlace());
375-
TransformFunctor<Functor, T, DeviceContext> functor(
376+
z->mutable_data<OutType>(ctx.GetPlace());
377+
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
376378
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
377379

378380
auto x_dims = x->dims();

python/paddle/v2/fluid/tests/test_compare_op.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ def test_output(self):
3838
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
3939
create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
4040
create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b)
41-
create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b)
42-
create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b)
4341
create_test_class('equal', _type_name, lambda _a, _b: _a == _b)
4442

4543
if __name__ == '__main__':

0 commit comments

Comments
 (0)