Skip to content

Commit f74fb79

Browse files
authored
Compare Operator (#5325)
* Compare Operator
* Follow comments
1 parent 58db07b commit f74fb79

File tree

8 files changed

+212
-2
lines changed

8 files changed

+212
-2
lines changed

paddle/framework/tensor_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
5252
};
5353

5454
static inline size_t SizeOfType(std::type_index type) {
55-
SizeOfTypeFunctor<int, float, double, int16_t, int64_t> functor;
55+
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool> functor;
5656
size_t size = functor(type);
5757
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
5858
return size;

paddle/operators/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ function(op_library TARGET)
6262
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
6363
endif()
6464

65+
if ("${TARGET}" STREQUAL "compare_op")
66+
set(pybind_flag 1)
67+
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
68+
endif()
69+
6570
# pool_with_index_op contains several operators
6671
if ("${TARGET}" STREQUAL "pool_with_index_op")
6772
set(pybind_flag 1)

paddle/operators/compare_op.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/compare_op.h"
16+
#include "paddle/framework/op_registry.h"
17+
namespace paddle {
18+
namespace operators {
19+
template <typename OpComment>
20+
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
21+
public:
22+
CompareOpProtoMaker(framework::OpProto *proto,
23+
framework::OpAttrChecker *op_checker)
24+
: OpProtoAndCheckerMaker(proto, op_checker) {
25+
OpComment comment;
26+
AddInput("X",
27+
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
28+
comment.type));
29+
AddInput("Y", string::Sprintf(
30+
"(LoDTensor) the right hand operand of %s operator",
31+
comment.type));
32+
AddOutput("Out", string::Sprintf(
33+
"(LoDTensor) n-dim bool tensor. Each element is %s",
34+
comment.equation));
35+
AddComment(string::Sprintf(R"DOC(%s Operator
36+
37+
It operates element-wise on X and Y, and returns the Out. Each of them is a
38+
N-dim tensor. X and Y could be any type. The each element of the Out tensor is
39+
calculated by %s
40+
)DOC",
41+
comment.type, comment.equation));
42+
}
43+
};
44+
45+
template <typename OpComment>
46+
class CompareOpInferShape : public framework::InferShapeBase {
47+
public:
48+
void operator()(framework::InferShapeContext *context) const override {
49+
OpComment comment;
50+
PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X",
51+
comment.type);
52+
PADDLE_ENFORCE(context->HasInput("Y"), "%s operator must has input Y",
53+
comment.type);
54+
auto dim_x = context->GetInputDim("X");
55+
auto dim_y = context->GetInputDim("Y");
56+
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
57+
"The number of elements in X and Y should be same");
58+
59+
context->SetOutputDim("Out", context->GetInputDim("X"));
60+
context->ShareLoD("X", "Out");
61+
}
62+
};
63+
64+
} // namespace operators
65+
} // namespace paddle
66+
67+
#define REGISTER_LOGICAL_OP(op_type, _equation) \
68+
struct _##op_type##Comment { \
69+
static char type[]; \
70+
static char equation[]; \
71+
}; \
72+
char _##op_type##Comment::type[]{#op_type}; \
73+
char _##op_type##Comment::equation[]{_equation}; \
74+
REGISTER_OP_WITH_KERNEL( \
75+
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
76+
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
77+
::paddle::framework::EmptyGradOpMaker);
78+
79+
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
80+
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
81+
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
82+
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);

paddle/operators/compare_op.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/compare_op.h"
16+
17+
// GPU kernel registrations for the comparison operators; registration order
// is irrelevant.
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);

paddle/operators/compare_op.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <math.h>
17+
#include <type_traits>
18+
#include "paddle/framework/op_registry.h"
19+
#include "paddle/platform/transform.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
template <typename T>
25+
struct LessThanFunctor {
26+
using ELEM_TYPE = T;
27+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
28+
};
29+
30+
template <typename T>
31+
struct EqualFunctor {
32+
using ELEM_TYPE = T;
33+
HOSTDEVICE bool operator()(const T& a, const T& b) const {
34+
if (std::is_floating_point<T>::value) {
35+
// This branch will be optimized while compiling if T is integer. It is
36+
// safe to cast a and b to double.
37+
return fabs(static_cast<double>(a - b)) < 1e-8;
38+
} else {
39+
return (a == b);
40+
}
41+
}
42+
};
43+
44+
template <typename Place, typename Functor>
45+
class CompareOpKernel
46+
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
47+
public:
48+
void Compute(const framework::ExecutionContext& context) const override {
49+
using T = typename Functor::ELEM_TYPE;
50+
auto* x = context.Input<framework::Tensor>("X");
51+
auto* y = context.Input<framework::Tensor>("Y");
52+
auto* out = context.Output<framework::Tensor>("Out");
53+
Functor binary_func;
54+
platform::Transform<Place> trans;
55+
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
56+
y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
57+
binary_func);
58+
}
59+
};
60+
61+
} // namespace operators
62+
} // namespace paddle
63+
64+
/* Registers the kernels of `op_type` on device `dev` (CPU or GPU) for every
   supported element type: int, int64_t, float, double. `functor` is the
   element-wise comparison functor template (e.g. LessThanFunctor). */
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor)                     \
  REGISTER_OP_##dev##_KERNEL(                                              \
      op_type,                                                             \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<int>>,                  \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<int64_t>>,              \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<float>>,                \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<double>>);

paddle/pybind/pybind.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,13 @@ PYBIND11_PLUGIN(core) {
113113
.def("set", PyCPUTensorSetFromArray<int>)
114114
.def("set", PyCPUTensorSetFromArray<double>)
115115
.def("set", PyCPUTensorSetFromArray<int64_t>)
116+
.def("set", PyCPUTensorSetFromArray<bool>)
116117
#ifdef PADDLE_WITH_CUDA
117118
.def("set", PyCUDATensorSetFromArray<float>)
118119
.def("set", PyCUDATensorSetFromArray<int>)
119120
.def("set", PyCUDATensorSetFromArray<double>)
120121
.def("set", PyCUDATensorSetFromArray<int64_t>)
122+
.def("set", PyCUDATensorSetFromArray<bool>)
121123
#endif
122124
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
123125
.def("set_float_element", TensorSetElement<float>)

paddle/pybind/tensor_py.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
8585
} // namespace details
8686
// Builds a pybind11 buffer descriptor for the tensor, dispatching over the
// supported element types (bool included).
inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
  return details::CastToPyBufferImpl<true, 0, float, int, double, int64_t,
                                     bool>()(tensor);
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import op_test
2+
import unittest
3+
import numpy
4+
5+
6+
def create_test_class(op_type, typename, callback):
    """Create and register an OpTest subclass for one compare op / dtype pair.

    Args:
        op_type: registered operator name, e.g. 'less_than' or 'equal'.
        typename: numpy dtype name used for both operands.
        callback: callable computing the expected element-wise bool result.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            shape = (10, 7)
            if numpy.issubdtype(numpy.dtype(typename), numpy.integer):
                # numpy.random.random() yields floats in [0, 1); casting those
                # to an integer dtype produces all zeros, which made the
                # integer tests degenerate (less_than always False, equal
                # always True). Draw genuine random integers instead.
                a = numpy.random.randint(0, 100, size=shape).astype(typename)
                b = numpy.random.randint(0, 100, size=shape).astype(typename)
            else:
                a = numpy.random.random(size=shape).astype(typename)
                b = numpy.random.random(size=shape).astype(typename)
            c = callback(a, b)
            self.inputs = {'X': a, 'Y': b}
            self.outputs = {'Out': c}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    # Give the class a unique, descriptive name and expose it at module
    # level so the unittest loader discovers it.
    cls_name = "{0}_{1}".format(op_type, typename)
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls
22+
23+
24+
# Use an ordered sequence (not a set literal) so test classes are registered
# in a deterministic, reproducible order across interpreter runs.
for _type_name in ('float32', 'float64', 'int32', 'int64'):
    create_test_class('less_than', _type_name, lambda _a, _b: _a < _b)
    create_test_class('equal', _type_name, lambda _a, _b: _a == _b)

if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments (0)