|
14 | 14 |
|
15 | 15 | #include "paddle/operators/compare_op.h"
|
16 | 16 | #include "paddle/framework/op_registry.h"
|
| 17 | + |
17 | 18 | namespace paddle {
|
18 | 19 | namespace operators {
|
19 | 20 | template <typename OpComment>
|
@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
|
61 | 62 | }
|
62 | 63 | };
|
63 | 64 |
|
| 65 | +class CompareOp : public framework::OperatorWithKernel { |
| 66 | + public: |
| 67 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 68 | + |
| 69 | + protected: |
| 70 | + framework::OpKernelType GetKernelType( |
| 71 | + const framework::ExecutionContext &ctx) const override { |
| 72 | + framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx); |
| 73 | + // CompareOp kernel's device type is decided by input tensor place |
| 74 | + kt.place_ = ctx.Input<framework::LoDTensor>("X")->place(); |
| 75 | + return kt; |
| 76 | + } |
| 77 | +}; |
| 78 | + |
64 | 79 | } // namespace operators
|
65 | 80 | } // namespace paddle
|
66 | 81 |
|
67 |
| -#define REGISTER_LOGICAL_OP(op_type, _equation) \ |
68 |
| - struct _##op_type##Comment { \ |
69 |
| - static char type[]; \ |
70 |
| - static char equation[]; \ |
71 |
| - }; \ |
72 |
| - char _##op_type##Comment::type[]{#op_type}; \ |
73 |
| - char _##op_type##Comment::equation[]{_equation}; \ |
74 |
| - REGISTER_OP_WITH_KERNEL( \ |
75 |
| - op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ |
76 |
| - ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ |
| 82 | +#define REGISTER_LOGICAL_OP(op_type, _equation) \ |
| 83 | + struct _##op_type##Comment { \ |
| 84 | + static char type[]; \ |
| 85 | + static char equation[]; \ |
| 86 | + }; \ |
| 87 | + char _##op_type##Comment::type[]{#op_type}; \ |
| 88 | + char _##op_type##Comment::equation[]{_equation}; \ |
| 89 | + REGISTER_OPERATOR( \ |
| 90 | + op_type, ::paddle::operators::CompareOp, \ |
| 91 | + ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ |
| 92 | + ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ |
77 | 93 | ::paddle::framework::EmptyGradOpMaker);
|
78 | 94 |
|
79 | 95 | REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
|
|
0 commit comments