Skip to content

Commit 82866d4

Browse files
committed
Add register kernel functor and shrink reshape op
* Shrink reshape_op library size * User can register a standard C++ functor as a op kernel
1 parent 75ae426 commit 82866d4

File tree

4 files changed

+26
-35
lines changed

4 files changed

+26
-35
lines changed

paddle/fluid/framework/op_registry.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,15 @@ struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
146146
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
147147
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
148148
DataTypeAndKernelType...> {
149-
using KERNEL_TYPE =
149+
using Functor =
150150
typename std::tuple_element<I + 1,
151151
std::tuple<DataTypeAndKernelType...>>::type;
152152
using T =
153153
typename std::tuple_element<I,
154154
std::tuple<DataTypeAndKernelType...>>::type;
155155

156156
void operator()(const char* op_type, const char* library_type) const {
157-
RegisterKernelClass<PlaceType, T>(
158-
op_type, library_type, [](const framework::ExecutionContext& ctx) {
159-
KERNEL_TYPE().Compute(ctx);
160-
});
157+
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor());
161158

162159
constexpr auto size =
163160
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
@@ -238,11 +235,11 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
238235
return 0; \
239236
}
240237

241-
#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...) \
242-
REGISTER_OP_KERNEL_EX(p_type, CUDA, ::paddle::platform::CUDAPlace, \
238+
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
239+
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
243240
__VA_ARGS__)
244241

245-
#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \
242+
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
246243
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
247244

248245
/**

paddle/fluid/operators/reshape_op.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
107107
}
108108
};
109109

110-
void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
110+
void ReshapeKernel::operator()(const framework::ExecutionContext &ctx) const {
111111
auto *out = ctx.Output<framework::LoDTensor>("Out");
112112
auto *in = ctx.Input<framework::LoDTensor>("X");
113113

@@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
147147
out->Resize(out_dims);
148148
}
149149
}
150-
void ReshapeGradKernelBase::Compute(
150+
void ReshapeGradKernel::operator()(
151151
const framework::ExecutionContext &ctx) const {
152152
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
153153
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
@@ -172,10 +172,10 @@ namespace ops = paddle::operators;
172172
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
173173
paddle::framework::DefaultGradOpDescMaker<true>);
174174
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
175-
REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
176-
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
177-
ops::ReshapeKernel);
178-
REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>,
179-
ops::ReshapeGradKernel<double>,
180-
ops::ReshapeGradKernel<int>,
181-
ops::ReshapeGradKernel<int64_t>);
175+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
176+
ops::ReshapeKernel, int, ops::ReshapeKernel,
177+
int64_t, ops::ReshapeKernel);
178+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
179+
double, ops::ReshapeGradKernel, int,
180+
ops::ReshapeGradKernel, int64_t,
181+
ops::ReshapeGradKernel);

paddle/fluid/operators/reshape_op.cu.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/reshape_op.h"
1616
namespace ops = paddle::operators;
17-
REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
18-
ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
19-
ops::ReshapeKernel);
20-
REGISTER_OP_CUDA_KERNEL(reshape_grad,
21-
paddle::operators::ReshapeGradKernel<float>,
22-
paddle::operators::ReshapeGradKernel<double>,
23-
paddle::operators::ReshapeGradKernel<int>,
24-
paddle::operators::ReshapeGradKernel<int64_t>);
17+
18+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
19+
ops::ReshapeKernel, int, ops::ReshapeKernel,
20+
int64_t, ops::ReshapeKernel);
21+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
22+
double, ops::ReshapeGradKernel, int,
23+
ops::ReshapeGradKernel, int64_t,
24+
ops::ReshapeGradKernel);

paddle/fluid/operators/reshape_op.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,21 +118,15 @@ class ReshapeOp : public framework::OperatorWithKernel {
118118
}
119119
};
120120

121-
class ReshapeKernel : public framework::OpKernelBase {
121+
class ReshapeKernel {
122122
public:
123-
void Compute(const framework::ExecutionContext &ctx) const final;
123+
void operator()(const framework::ExecutionContext &ctx) const;
124124
};
125125

126-
class ReshapeGradKernelBase : public framework::OpKernelBase {
126+
class ReshapeGradKernel {
127127
public:
128-
void Compute(const framework::ExecutionContext &ctx) const;
128+
void operator()(const framework::ExecutionContext &ctx) const;
129129
};
130130

131-
template <typename T>
132-
class ReshapeGradKernel : public ReshapeGradKernelBase {
133-
public:
134-
// Tell register element type.
135-
using ELEMENT_TYPE = T;
136-
};
137131
} // namespace operators
138132
} // namespace paddle

0 commit comments

Comments
 (0)