Skip to content

Commit 8d76cf3

Browse files
author
chengduo
authored
Fix TensorCopy bug (#11822)
* Fix tensorcopy bug
* follow comment
* Refine TensorCopy
1 parent 5988d0c commit 8d76cf3

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

paddle/fluid/framework/parallel_executor.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,6 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
253253
t->set_lod(lod_tensors[j].lod());
254254
}
255255
}
256-
for (auto &p : member_->places_) {
257-
platform::DeviceContextPool::Instance().Get(p)->Wait();
258-
}
259256
}
260257

261258
ParallelExecutor::~ParallelExecutor() {

paddle/fluid/framework/tensor_util.cc

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,47 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
6969
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
7070
auto stream =
7171
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
72-
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
72+
if (platform::is_same_place(src_place, dst_place)) {
73+
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
74+
stream);
75+
} else {
76+
// NOTE(zcd): Because TensorCopy is an async operation, when the src_place
77+
// and dst_place are two different GPU, to ensure that the operation can
78+
// be carried out correctly, we should make ctx wait.
79+
// If ctx_place and src_place are the same, we should add ctx.Wait()
80+
// after memory::Copy; if ctx_place and dst_place are the same, we should
81+
// add ctx.Wait() before memory::Copy.
82+
if (platform::is_same_place(ctx_place, src_place)) {
83+
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
84+
stream);
85+
ctx.Wait();
86+
} else if (platform::is_same_place(ctx_place, dst_place)) {
87+
ctx.Wait();
88+
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
89+
stream);
90+
} else {
91+
PADDLE_THROW("ctx is not belong to dst_gpu_place or src_gpu_place.");
92+
}
93+
}
7394
}
7495
#endif
7596
}
7697

7798
void TensorCopy(const Tensor& src, const platform::Place& dst_place,
7899
Tensor* dst) {
100+
// NOTE(zcd): If the src.place() and dst_place are two different GPU,
101+
// the copy operation is carried out on the dst_place's stream. This is
102+
// very important, because TensorCopy is an async operator, and in most
103+
// case, once this copy operator returns, dst is to be used in dst_place's
104+
// stream, if this copy operation is carried out on the src_place's stream,
105+
// when dst is used in dst_place's stream the copy operation may be
106+
// not completed.
79107
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
80108
const platform::DeviceContext* dev_ctx;
81-
if (platform::is_gpu_place(src.place())) {
82-
dev_ctx = pool.Get(src.place());
83-
} else {
109+
if (platform::is_gpu_place(dst_place)) {
84110
dev_ctx = pool.Get(dst_place);
111+
} else {
112+
dev_ctx = pool.Get(src.place());
85113
}
86114
TensorCopy(src, dst_place, *dev_ctx, dst);
87115
}

0 commit comments

Comments (0)