
Commit e2e82bd

Accelerate Reshape op
1 parent e1904ac commit e2e82bd

File tree

paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/sequence_concat_op.cc

2 files changed: +51 -36 lines

paddle/fluid/operators/reshape_op.cc

Lines changed: 48 additions & 34 deletions
@@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of
 [2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
 
 3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
-Attr(shape) still should be set correctly to gurantee shape inference in
+Attr(shape) still should be set correctly to guarantee shape inference in
 compile-time.
 
 )DOC");
@@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
 class ReshapeKernel {
  public:
  void operator()(const framework::ExecutionContext &ctx) const {
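Note: ReshapeKernel (and ReshapeGradKernel below) becomes a class template on the element type T. The kernel body needs T to obtain a typed data pointer via in->data<T>(), which is what makes the pointer comparison in the next hunk possible.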
@@ -227,21 +228,25 @@ class ReshapeKernel {
                      "sequence_reshape op.");
    }
 
-    out->mutable_data(ctx.GetPlace(), in->type());
-    framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    if (in->data<T>() !=
+        reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
+      framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    }
    out->Resize(out_dims);
  }
 };
 
+template <typename T>
 class ReshapeGradKernel {
  public:
  void operator()(const framework::ExecutionContext &ctx) const {
    auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto in_dims = d_x->dims();
 
-    d_x->mutable_data(ctx.GetPlace(), d_out->type());
-    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
+      framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    }
    d_x->Resize(in_dims);
  }
 };
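This hunk is the core of the speedup: mutable_data may return the very buffer the input already occupies (reshape can be a metadata-only operation), and in that case the TensorCopySync round trip is skipped entirely. Below is a minimal standalone sketch of the same copy-elision pattern, using a toy Buffer type and memcpy in place of Paddle's tensor and TensorCopySync; all names in it are illustrative, not Paddle APIs.

#include <cstring>
#include <iostream>
#include <vector>

// Toy stand-in for a tensor that may hand out an existing buffer.
struct Buffer {
  std::vector<float> data;
  float *mutable_data(size_t n) {
    if (data.size() != n) data.resize(n);  // (re)allocate only when needed
    return data.data();
  }
};

// Copy src into dst unless dst already aliases src's storage.
void reshape_copy(const float *src, size_t n, Buffer &dst) {
  float *out = dst.mutable_data(n);
  if (src != out) {  // pointers differ: a real copy is needed
    std::memcpy(out, src, n * sizeof(float));
  }  // otherwise: zero-copy, metadata-only reshape
}

int main() {
  Buffer b;
  std::vector<float> x = {1, 2, 3, 4};
  reshape_copy(x.data(), x.size(), b);            // distinct buffers: copies
  reshape_copy(b.data.data(), b.data.size(), b);  // aliasing: copy skipped
  std::cout << b.data[2] << "\n";                 // prints 3
  return 0;
}

The comparison is against the pointer returned by the allocation call itself, so the check is exact: a copy happens only when source and destination storage actually differ.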
@@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
      : ReshapeOp(type, inputs, outputs, attrs) {}
 
  void InferShape(framework::InferShapeContext *ctx) const override {
-    ReshapeOp::InferShape(ctx);
    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                   "Output(XShape) of ReshapeOp should not be null.");
    const auto &x_dims = ctx->GetInputDim("X");
@@ -270,6 +274,8 @@ class Reshape2Op : public ReshapeOp {
    }
    ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
    ctx->ShareLoD("X", /*->*/ "XShape");
+
+    ReshapeOp::InferShape(ctx);
  }
 };
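Note: taken together with the previous hunk, the base-class call ReshapeOp::InferShape(ctx) moves from the first statement of Reshape2Op::InferShape to the last, so the XShape output is validated and configured before Out's shape is inferred. The commit itself does not record the motivation for this reordering.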

@@ -335,38 +341,46 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 #ifdef PADDLE_WITH_CUDA
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
 #endif
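Because ReshapeKernel and ReshapeGradKernel are now templates, the registration macros must name one explicit instantiation per dtype instead of repeating a single non-template functor. A toy sketch of the idea behind per-dtype functor registration, with a type-erased registry standing in for Paddle's; the registry and names below are illustrative, not Paddle's actual machinery.

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy kernel functor templated on the element type, like ReshapeKernel<T>.
template <typename T>
struct PrintSizeKernel {
  void operator()() const { std::cout << sizeof(T) << " bytes\n"; }
};

// Toy registry: "op/dtype" key -> type-erased functor.
std::map<std::string, std::function<void()>> registry;

template <typename T>
void RegisterKernel(const std::string &op, const std::string &dtype) {
  registry[op + "/" + dtype] = PrintSizeKernel<T>{};  // one instantiation per dtype
}

int main() {
  RegisterKernel<float>("reshape", "float");
  RegisterKernel<int64_t>("reshape", "int64_t");
  registry["reshape/int64_t"]();  // dispatch by dtype at run time: prints "8 bytes"
  return 0;
}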

paddle/fluid/operators/sequence_concat_op.cc

Lines changed: 3 additions & 2 deletions
@@ -90,11 +90,12 @@ REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
                   paddle::framework::DefaultGradOpDescMaker<false>);
 template <typename T>
 using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
-REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
+REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
+                       Kernel<int64_t>);
 REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
                   op::SeqConcatGradShapeInferer);
 template <typename T>
 using GradKernel =
     op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
 REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
-                       GradKernel<double>);
+                       GradKernel<double>, GradKernel<int64_t>);
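Note: sequence_concat and sequence_concat_grad previously registered only float and double CPU kernels; the commit adds int64_t instantiations of the same kernel aliases, extending dtype coverage so that integer sequences (presumably id-like data) can be concatenated without a cast.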
