Skip to content

Commit d9b202e

Browse files
committed
Move tensor copy src_ptr and dst_ptr check to TensorCopySync function
test=develop
1 parent f408488 commit d9b202e

File tree

2 files changed

+43
-45
lines changed

2 files changed

+43
-45
lines changed

paddle/fluid/framework/tensor_util.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
114114
auto dst_ptr = dst->mutable_data(dst_place, src.type());
115115
auto size = src.numel() * SizeOfType(src.type());
116116
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
117+
if (src_ptr == dst_ptr) {
118+
VLOG(3) << "Skip copy the same data from " << src.place() << " to "
119+
<< dst_place;
120+
return;
121+
}
117122
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
118123
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
119124
}
@@ -132,6 +137,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
132137
platform::is_gpu_place(dst_place)) {
133138
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
134139
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
140+
if (src_ptr == dst_ptr &&
141+
src_gpu_place.GetDeviceId() == dst_gpu_place.GetDeviceId()) {
142+
VLOG(3) << "Skip copy the same data from " << src.place() << " to "
143+
<< dst_place;
144+
return;
145+
}
135146
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);
136147
}
137148
#endif

paddle/fluid/operators/reshape_op.cc

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
195195
}
196196
};
197197

198-
template <typename T>
199198
class ReshapeKernel {
200199
public:
201200
void operator()(const framework::ExecutionContext &ctx) const {
@@ -228,25 +227,21 @@ class ReshapeKernel {
228227
"sequence_reshape op.");
229228
}
230229

231-
if (in->data<T>() !=
232-
reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
233-
framework::TensorCopySync(*in, ctx.GetPlace(), out);
234-
}
230+
out->mutable_data(ctx.GetPlace(), in->type());
231+
framework::TensorCopySync(*in, ctx.GetPlace(), out);
235232
out->Resize(out_dims);
236233
}
237234
};
238235

239-
template <typename T>
240236
class ReshapeGradKernel {
241237
public:
242238
void operator()(const framework::ExecutionContext &ctx) const {
243239
auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
244240
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
245241
auto in_dims = d_x->dims();
246242

247-
if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
248-
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
249-
}
243+
d_x->mutable_data(ctx.GetPlace(), d_out->type());
244+
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
250245
d_x->Resize(in_dims);
251246
}
252247
};
@@ -341,46 +336,38 @@ namespace ops = paddle::operators;
341336
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
342337
paddle::framework::DefaultGradOpDescMaker<true>);
343338
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
344-
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
345-
double, ops::ReshapeKernel<double>, int,
346-
ops::ReshapeKernel<int>, int64_t,
347-
ops::ReshapeKernel<int64_t>);
348-
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
349-
ops::ReshapeGradKernel<float>, double,
350-
ops::ReshapeGradKernel<double>, int,
351-
ops::ReshapeGradKernel<int>, int64_t,
352-
ops::ReshapeGradKernel<int64_t>);
339+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
340+
ops::ReshapeKernel, int, ops::ReshapeKernel,
341+
int64_t, ops::ReshapeKernel);
342+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
343+
double, ops::ReshapeGradKernel, int,
344+
ops::ReshapeGradKernel, int64_t,
345+
ops::ReshapeGradKernel);
353346

354347
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
355348
ops::Reshape2GradMaker);
356349
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
357-
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
358-
double, ops::ReshapeKernel<double>, int,
359-
ops::ReshapeKernel<int>, int64_t,
360-
ops::ReshapeKernel<int64_t>);
361-
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
362-
ops::ReshapeGradKernel<float>, double,
363-
ops::ReshapeGradKernel<double>, int,
364-
ops::ReshapeGradKernel<int>, int64_t,
365-
ops::ReshapeGradKernel<int64_t>);
350+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
351+
ops::ReshapeKernel, int, ops::ReshapeKernel,
352+
int64_t, ops::ReshapeKernel);
353+
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
354+
double, ops::ReshapeGradKernel, int,
355+
ops::ReshapeGradKernel, int64_t,
356+
ops::ReshapeGradKernel);
366357

367358
#ifdef PADDLE_WITH_CUDA
368-
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
369-
double, ops::ReshapeKernel<double>, int,
370-
ops::ReshapeKernel<int>, int64_t,
371-
ops::ReshapeKernel<int64_t>);
372-
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
373-
ops::ReshapeGradKernel<float>, double,
374-
ops::ReshapeGradKernel<double>, int,
375-
ops::ReshapeGradKernel<int>, int64_t,
376-
ops::ReshapeGradKernel<int64_t>);
377-
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
378-
double, ops::ReshapeKernel<double>, int,
379-
ops::ReshapeKernel<int>, int64_t,
380-
ops::ReshapeKernel<int64_t>);
381-
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
382-
ops::ReshapeGradKernel<float>, double,
383-
ops::ReshapeGradKernel<double>, int,
384-
ops::ReshapeGradKernel<int>, int64_t,
385-
ops::ReshapeGradKernel<int64_t>);
359+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
360+
ops::ReshapeKernel, int, ops::ReshapeKernel,
361+
int64_t, ops::ReshapeKernel);
362+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
363+
double, ops::ReshapeGradKernel, int,
364+
ops::ReshapeGradKernel, int64_t,
365+
ops::ReshapeGradKernel);
366+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
367+
ops::ReshapeKernel, int, ops::ReshapeKernel,
368+
int64_t, ops::ReshapeKernel);
369+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
370+
double, ops::ReshapeGradKernel, int,
371+
ops::ReshapeGradKernel, int64_t,
372+
ops::ReshapeGradKernel);
386373
#endif

0 commit comments

Comments (0)