Skip to content

Commit 7f22a6d

Browse files
authored
Merge pull request #5465 from reyoung/feature/compare_op_support_cpu
CompareOp's kernel device type is decided by input tensor place
2 parents 2a76b42 + 3187451 commit 7f22a6d

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

paddle/operators/compare_op.cc

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/operators/compare_op.h"
1616
#include "paddle/framework/op_registry.h"
17+
1718
namespace paddle {
1819
namespace operators {
1920
template <typename OpComment>
@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
6162
}
6263
};
6364

65+
class CompareOp : public framework::OperatorWithKernel {
66+
public:
67+
using framework::OperatorWithKernel::OperatorWithKernel;
68+
69+
protected:
70+
framework::OpKernelType GetKernelType(
71+
const framework::ExecutionContext &ctx) const override {
72+
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
73+
// CompareOp kernel's device type is decided by input tensor place
74+
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
75+
return kt;
76+
}
77+
};
78+
6479
} // namespace operators
6580
} // namespace paddle
6681

67-
#define REGISTER_LOGICAL_OP(op_type, _equation) \
68-
struct _##op_type##Comment { \
69-
static char type[]; \
70-
static char equation[]; \
71-
}; \
72-
char _##op_type##Comment::type[]{#op_type}; \
73-
char _##op_type##Comment::equation[]{_equation}; \
74-
REGISTER_OP_WITH_KERNEL( \
75-
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
76-
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
82+
#define REGISTER_LOGICAL_OP(op_type, _equation) \
83+
struct _##op_type##Comment { \
84+
static char type[]; \
85+
static char equation[]; \
86+
}; \
87+
char _##op_type##Comment::type[]{#op_type}; \
88+
char _##op_type##Comment::equation[]{_equation}; \
89+
REGISTER_OPERATOR( \
90+
op_type, ::paddle::operators::CompareOp, \
91+
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
92+
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
7793
::paddle::framework::EmptyGradOpMaker);
7894

7995
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");

paddle/platform/transform.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
4949
template <typename InputIter, typename OutputIter, typename UnaryOperation>
5050
void operator()(const DeviceContext& context, InputIter first, InputIter last,
5151
OutputIter result, UnaryOperation op) {
52-
auto place = context.GetPlace();
53-
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
5452
std::transform(first, last, result, op);
5553
}
5654

@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
5957
void operator()(const DeviceContext& context, InputIter1 first1,
6058
InputIter1 last1, InputIter2 first2, OutputIter result,
6159
BinaryOperation op) {
62-
auto place = context.GetPlace();
63-
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
6460
std::transform(first1, last1, first2, result, op);
6561
}
6662
};

0 commit comments

Comments
 (0)