Commit 1ce478f

Polish reshape op

1 parent 81f22bb commit 1ce478f

4 files changed: +157 -90 lines changed


paddle/fluid/framework/op_registry.h

Lines changed: 74 additions & 10 deletions
@@ -76,23 +76,27 @@ class OpRegistry {
 template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
 struct OpKernelRegistrarFunctor;
 
+template <typename PlaceType, typename T, typename KernelType>
+inline void RegisterKernelClass(const char* op_type, const char* library_type) {
+  std::string library(library_type);
+  std::string data_layout = "ANYLAYOUT";
+  if (library == "MKLDNN") {
+    data_layout = "MKLDNNLAYOUT";
+  }
+  OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
+                   StringToDataLayout(data_layout),
+                   StringToLibraryType(library_type));
+  OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
+}
+
 template <typename PlaceType, size_t I, typename... KernelTypes>
 struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
   using KERNEL_TYPE =
       typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
 
   void operator()(const char* op_type, const char* library_type) const {
     using T = typename KERNEL_TYPE::ELEMENT_TYPE;
-    std::string library(library_type);
-    std::string data_layout = "ANYLAYOUT";
-    if (library == "MKLDNN") {
-      data_layout = "MKLDNNLAYOUT";
-    }
-    OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
-                     StringToDataLayout(data_layout),
-                     StringToLibraryType(library_type));
-    OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
-
+    RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
     constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
     OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
         func;
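The hoisted RegisterKernelClass helper is nothing more than an insert into the global two-level kernel map behind OperatorWithKernel::AllOpKernels(): op name first, then an OpKernelType key built from the element type, place, layout, and library. Below is a minimal self-contained sketch of that registry shape; KernelBase, KernelKey, RegisterKernel, and ReshapeLikeKernel are illustrative stand-ins, not Paddle names.

#include <map>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Illustrative stand-ins for Paddle's OpKernelBase and OpKernelType.
struct KernelBase {
  virtual ~KernelBase() = default;
};
using KernelKey = std::pair<std::type_index, std::string>;  // (element type, library)

// op name -> kernel key -> kernel instance, mirroring AllOpKernels()[op_type][key].
std::map<std::string, std::map<KernelKey, std::unique_ptr<KernelBase>>> &AllKernels() {
  static std::map<std::string, std::map<KernelKey, std::unique_ptr<KernelBase>>> m;
  return m;
}

template <typename T, typename KernelType>
void RegisterKernel(const std::string &op_type, const std::string &library) {
  AllKernels()[op_type][{std::type_index(typeid(T)), library}] =
      std::make_unique<KernelType>();
}

struct ReshapeLikeKernel : KernelBase {};

int main() {
  // One map entry per element type, all pointing at the same kernel class.
  RegisterKernel<float, ReshapeLikeKernel>("reshape", "CPU");
  RegisterKernel<double, ReshapeLikeKernel>("reshape", "CPU");
  return AllKernels()["reshape"].size() == 2 ? 0 : 1;
}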
@@ -116,6 +120,47 @@ class OpKernelRegistrar : public Registrar {
   }
 };
 
+template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
+struct OpKernelRegistrarFunctorEx;
+
+template <typename PlaceType, typename... DataTypeAndKernelType>
+class OpKernelRegistrarEx : public Registrar {
+ public:
+  explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
+    OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
+        func;
+    func(op_type, library_type);
+  }
+};
+
+template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
+                                  DataTypeAndKernelType...> {
+  void operator()(const char* op_type, const char* library_type) const {}
+};
+
+template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
+                                  DataTypeAndKernelType...> {
+  using KERNEL_TYPE =
+      typename std::tuple_element<I + 1,
+                                  std::tuple<DataTypeAndKernelType...>>::type;
+  using T =
+      typename std::tuple_element<I,
+                                  std::tuple<DataTypeAndKernelType...>>::type;
+
+  void operator()(const char* op_type, const char* library_type) const {
+    RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
+
+    constexpr auto size =
+        std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
+    OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
+                               DataTypeAndKernelType...>
+        func;
+    func(op_type, library_type);
+  }
+};
+
 /**
  * check if MACRO is used in GLOBAL NAMESPACE.
  */
@@ -174,6 +219,25 @@ class OpKernelRegistrar : public Registrar {
 #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
 
+#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...)      \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
+      __reg_op_kernel_##op_type##_##library_type##__,                       \
+      "REGISTER_OP_KERNEL_EX must be called in global namespace");          \
+  static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
+      __op_kernel_registrar_##op_type##_##library_type##__(#op_type,        \
+                                                           #library_type);  \
+  int TouchOpKernelRegistrar_##op_type##_##library_type() {                 \
+    __op_kernel_registrar_##op_type##_##library_type##__.Touch();           \
+    return 0;                                                               \
+  }
+
+#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...)                      \
+  REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
+                        __VA_ARGS__)
+
+#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \
+  REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
+
 /**
  * Macro to mark what Operator and Kernel
  * we will use and tell the compiler to
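The *_EX registrars consume their template parameter pack two entries at a time: element I is the element type and element I + 1 the kernel class registered for it, so a call site reads type, kernel, type, kernel, and so on. A small standalone sketch of that pairwise recursion, with illustrative names (PairwiseRegister, FloatKernel, IntKernel) rather than Paddle's:

#include <cstddef>
#include <cstdio>
#include <tuple>
#include <typeinfo>

// Recursive functor: element I is a data type, element I + 1 the kernel class
// registered for it, so each step advances by two.
template <bool at_end, std::size_t I, typename... Pack>
struct PairwiseRegister;

template <std::size_t I, typename... Pack>
struct PairwiseRegister<true, I, Pack...> {
  void operator()(const char *) const {}  // base case: pack exhausted
};

template <std::size_t I, typename... Pack>
struct PairwiseRegister<false, I, Pack...> {
  using T = typename std::tuple_element<I, std::tuple<Pack...>>::type;
  using Kernel = typename std::tuple_element<I + 1, std::tuple<Pack...>>::type;

  void operator()(const char *op_type) const {
    std::printf("%s: register %s for element type %s\n", op_type,
                typeid(Kernel).name(), typeid(T).name());
    constexpr std::size_t size = sizeof...(Pack);
    PairwiseRegister<(I + 2 >= size), I + 2, Pack...>()(op_type);
  }
};

struct FloatKernel {};
struct IntKernel {};

int main() {
  // Mirrors REGISTER_OP_CPU_KERNEL_EX(my_op, float, FloatKernel, int, IntKernel).
  PairwiseRegister<false, 0, float, FloatKernel, int, IntKernel>()("my_op");
}

Stepping by two and terminating once I + 2 >= size is what lets an untemplated kernel class such as ops::ReshapeKernel be listed once per element type, instead of spelling out a separate ReshapeKernel<T> instantiation for each type.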

paddle/fluid/operators/reshape_op.cc

Lines changed: 65 additions & 9 deletions
@@ -107,19 +107,75 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
+  auto *out = ctx.Output<framework::LoDTensor>("Out");
+  auto *in = ctx.Input<framework::LoDTensor>("X");
+
+  auto *shape_tensor = ctx.HasInput("Shape")
+                           ? ctx.Input<framework::LoDTensor>("Shape")
+                           : nullptr;
+
+  framework::DDim out_dims = out->dims();
+
+  if (shape_tensor) {
+    auto *shape_data = shape_tensor->data<int>();
+    framework::Tensor cpu_shape_tensor;
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
+      shape_data = cpu_shape_tensor.data<int>();
+    }
+    auto shape =
+        std::vector<int>(shape_data, shape_data + shape_tensor->numel());
+    out_dims = ReshapeOp::ValidateShape(shape, in->dims());
+  }
+  if (!in->lod().empty()) {
+    PADDLE_ENFORCE_EQ(out_dims[0], in->dims()[0],
+                      "Reshape operator cannot reshape an input sequence batch "
+                      "into an output sequence batch that has a different "
+                      "number of time steps. Please consider using "
+                      "sequence_reshape op.");
+  }
+
+  bool inplace = ctx.Attr<bool>("inplace");
+  out->Resize(out_dims);
+  if (!inplace) {
+    out->mutable_data(ctx.GetPlace(), in->type());
+    framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    out->Resize(out_dims);
+  } else {
+    out->ShareDataWith(*in);
+    out->Resize(out_dims);
+  }
+}
+void ReshapeGradKernelBase::Compute(
+    const framework::ExecutionContext &ctx) const {
+  auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+  auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+  d_x->mutable_data(ctx.GetPlace(), d_out->type());
+  bool inplace = ctx.Attr<bool>("inplace");
+
+  auto in_dims = d_x->dims();
+  if (!inplace) {
+    framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
+    ctx.device_context().Wait();
+    d_x->Resize(in_dims);
+  } else {
+    d_x->ShareDataWith(*d_out);
+    d_x->Resize(in_dims);
+  }
+}
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-using CPU = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
-                       ops::ReshapeKernel<CPU, double>,
-                       ops::ReshapeKernel<CPU, int>,
-                       ops::ReshapeKernel<CPU, int64_t>);
-REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
-                       ops::ReshapeGradKernel<CPU, double>,
-                       ops::ReshapeGradKernel<CPU, int>,
-                       ops::ReshapeGradKernel<CPU, int64_t>);
+REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
+                          ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
+                          ops::ReshapeKernel);
+REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>,
+                       ops::ReshapeGradKernel<double>,
+                       ops::ReshapeGradKernel<int>,
+                       ops::ReshapeGradKernel<int64_t>);
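Because ReshapeKernel is no longer a class template, the REGISTER_OP_CPU_KERNEL_EX call above binds the same kernel class to every listed element type. A rough hand-written equivalent of what the macro sets up, for illustration only: the real macro does this at static-initialization time through a static OpKernelRegistrarEx object, RegisterReshapeCpuKernels is a hypothetical wrapper, and RegisterKernelClass is assumed to live in paddle::framework alongside the registrars.

// One RegisterKernelClass call per element type, all binding ops::ReshapeKernel.
void RegisterReshapeCpuKernels() {
  using paddle::framework::RegisterKernelClass;
  using paddle::platform::CPUPlace;
  RegisterKernelClass<CPUPlace, float, ops::ReshapeKernel>("reshape", "CPU");
  RegisterKernelClass<CPUPlace, double, ops::ReshapeKernel>("reshape", "CPU");
  RegisterKernelClass<CPUPlace, int, ops::ReshapeKernel>("reshape", "CPU");
  RegisterKernelClass<CPUPlace, int64_t, ops::ReshapeKernel>("reshape", "CPU");
}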

paddle/fluid/operators/reshape_op.cu renamed to paddle/fluid/operators/reshape_op.cu.cc

Lines changed: 8 additions & 10 deletions
@@ -13,14 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/reshape_op.h"
-using CUDA = paddle::platform::CUDADeviceContext;
-
-REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
-                        paddle::operators::ReshapeKernel<CUDA, double>,
-                        paddle::operators::ReshapeKernel<CUDA, int>,
-                        paddle::operators::ReshapeKernel<CUDA, int64_t>);
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
+                           ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t,
+                           ops::ReshapeKernel);
 REGISTER_OP_CUDA_KERNEL(reshape_grad,
-                        paddle::operators::ReshapeGradKernel<CUDA, float>,
-                        paddle::operators::ReshapeGradKernel<CUDA, double>,
-                        paddle::operators::ReshapeGradKernel<CUDA, int>,
-                        paddle::operators::ReshapeGradKernel<CUDA, int64_t>);
+                        paddle::operators::ReshapeGradKernel<float>,
+                        paddle::operators::ReshapeGradKernel<double>,
+                        paddle::operators::ReshapeGradKernel<int>,
+                        paddle::operators::ReshapeGradKernel<int64_t>);

paddle/fluid/operators/reshape_op.h

Lines changed: 10 additions & 61 deletions
@@ -118,72 +118,21 @@ class ReshapeOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename DeviceContext, typename T>
-class ReshapeKernel : public framework::OpKernel<T> {
+class ReshapeKernel : public framework::OpKernelBase {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    auto *out = ctx.Output<framework::LoDTensor>("Out");
-    auto *in = ctx.Input<framework::LoDTensor>("X");
-
-    auto *shape_tensor = ctx.HasInput("Shape")
-                             ? ctx.Input<framework::LoDTensor>("Shape")
-                             : nullptr;
-
-    framework::DDim out_dims = out->dims();
-
-    if (shape_tensor) {
-      auto *shape_data = shape_tensor->data<int>();
-      framework::Tensor cpu_shape_tensor;
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
-        shape_data = cpu_shape_tensor.data<int>();
-      }
-      auto shape =
-          std::vector<int>(shape_data, shape_data + shape_tensor->numel());
-      out_dims = ReshapeOp::ValidateShape(shape, in->dims());
-    }
-    if (!in->lod().empty()) {
-      PADDLE_ENFORCE_EQ(
-          out_dims[0], in->dims()[0],
-          "Reshape operator cannot reshape an input sequence batch "
-          "into an output sequence batch that has a different "
-          "number of time steps. Please consider using "
-          "sequence_reshape op.");
-    }
+  void Compute(const framework::ExecutionContext &ctx) const final;
+};
 
-    bool inplace = ctx.Attr<bool>("inplace");
-    out->Resize(out_dims);
-    if (!inplace) {
-      out->mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopySync(*in, ctx.GetPlace(), out);
-      out->Resize(out_dims);
-    } else {
-      out->ShareDataWith(*in);
-      out->Resize(out_dims);
-    }
-  }
+class ReshapeGradKernelBase : public framework::OpKernelBase {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const;
 };
 
-template <typename DeviceContext, typename T>
-class ReshapeGradKernel : public framework::OpKernel<T> {
+template <typename T>
+class ReshapeGradKernel : public ReshapeGradKernelBase {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-
-    d_x->mutable_data<T>(ctx.GetPlace());
-    bool inplace = ctx.Attr<bool>("inplace");
-
-    auto in_dims = d_x->dims();
-    if (!inplace) {
-      framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
-      ctx.device_context().Wait();
-      d_x->Resize(in_dims);
-    } else {
-      d_x->ShareDataWith(*d_out);
-      d_x->Resize(in_dims);
-    }
-  }
+  // Tell register element type.
+  using ELEMENT_TYPE = T;
 };
 }  // namespace operators
 }  // namespace paddle
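The header now splits each kernel into a type-erased body and a thin typed shell: ReshapeKernel and ReshapeGradKernelBase hold the dtype-agnostic Compute, while ReshapeGradKernel<T> only exposes ELEMENT_TYPE, which the classic REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL path still reads (as KERNEL_TYPE::ELEMENT_TYPE) to build the kernel key. A minimal standalone sketch of that pattern, with illustrative names (GradKernelBase and GradKernel are not the Paddle classes):

#include <cstdio>
#include <typeinfo>

// Type-erased implementation shared by every element type.
class GradKernelBase {
 public:
  void Compute() const { std::puts("dtype-agnostic reshape-grad body"); }
};

// Thin wrapper whose only job is to carry the element type the classic
// registrar reads as KERNEL_TYPE::ELEMENT_TYPE when building the kernel key.
template <typename T>
class GradKernel : public GradKernelBase {
 public:
  using ELEMENT_TYPE = T;
};

int main() {
  GradKernel<float>().Compute();    // same body...
  GradKernel<int64_t>().Compute();  // ...whatever ELEMENT_TYPE says
  std::printf("key element type: %s\n",
              typeid(GradKernel<float>::ELEMENT_TYPE).name());
}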
