@@ -69,19 +69,47 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
69
69
PADDLE_ENFORCE (platform::is_gpu_place (ctx_place));
70
70
auto stream =
71
71
reinterpret_cast <const platform::CUDADeviceContext&>(ctx).stream ();
72
- memory::Copy (dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
72
+ if (platform::is_same_place (src_place, dst_place)) {
73
+ memory::Copy (dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
74
+ stream);
75
+ } else {
76
+ // NOTE(zcd): Because TensorCopy is an async operation, when the src_place
77
+ // and dst_place are two different GPU, to ensure that the operation can
78
+ // be carried out correctly, we should make ctx wait.
79
+ // If ctx_place and src_place are the same, we should add ctx.Wait()
80
+ // after memory::Copy; if ctx_place and dst_place are the same, we should
81
+ // add ctx.Wait() before memory::Copy.
82
+ if (platform::is_same_place (ctx_place, src_place)) {
83
+ memory::Copy (dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
84
+ stream);
85
+ ctx.Wait ();
86
+ } else if (platform::is_same_place (ctx_place, dst_place)) {
87
+ ctx.Wait ();
88
+ memory::Copy (dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
89
+ stream);
90
+ } else {
91
+ PADDLE_THROW (" ctx is not belong to dst_gpu_place or src_gpu_place." );
92
+ }
93
+ }
73
94
}
74
95
#endif
75
96
}
76
97
77
98
void TensorCopy (const Tensor& src, const platform::Place& dst_place,
78
99
Tensor* dst) {
100
+ // NOTE(zcd): If the src.place() and dst_place are two different GPU,
101
+ // the copy operation is carried out on the dst_place's stream. This is
102
+ // very important, because TensorCopy is an async operator, and in most
103
+ // case, once this copy operator returns, dst is to be used in dst_place's
104
+ // stream, if this copy operation is carried out on the src_place's stream,
105
+ // when dst is used in dst_place's stream the copy operation may be
106
+ // not completed.
79
107
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance ();
80
108
const platform::DeviceContext* dev_ctx;
81
- if (platform::is_gpu_place (src.place ())) {
82
- dev_ctx = pool.Get (src.place ());
83
- } else {
109
+ if (platform::is_gpu_place (dst_place)) {
84
110
dev_ctx = pool.Get (dst_place);
111
+ } else {
112
+ dev_ctx = pool.Get (src.place ());
85
113
}
86
114
TensorCopy (src, dst_place, *dev_ctx, dst);
87
115
}
0 commit comments