Skip to content

Commit 8c54f1f

Browse files
author
chengduo
authored
Merge pull request #10906 from chengduoZH/fix_data_trans
Fix DataTransFunc
2 parents 7d1332f + 17a076d commit 8c54f1f

File tree

1 file changed

+16
-22
lines changed

1 file changed

+16
-22
lines changed

paddle/fluid/framework/data_device_transform.cc

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,25 @@ limitations under the License. */
1616
namespace paddle {
1717
namespace framework {
1818

19-
static const platform::DeviceContext* GetDeviceContext(
20-
const platform::Place& src_place, const platform::Place& dst_place) {
21-
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
22-
23-
if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
24-
return pool.Get(src_place);
25-
} else if (platform::is_cpu_place(src_place) &&
26-
platform::is_gpu_place(dst_place)) {
27-
return pool.Get(dst_place);
28-
} else {
29-
PADDLE_THROW(
30-
"Currently, model parallelism is only supported between CPU and CUDA");
31-
}
32-
}
33-
34-
void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
35-
Tensor* out) {
19+
void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
20+
Tensor *out) {
3621
VLOG(3) << "DeviceTransform in, src_place " << in.place()
3722
<< " dst_place: " << dst_place;
38-
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
3923

40-
TensorCopy(in, dst_place, *dev_ctx, out);
41-
if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) {
42-
dev_ctx->Wait();
43-
}
24+
PADDLE_ENFORCE_NE(
25+
in.place().which(), dst_place.which(),
26+
"Currently, model parallelism is only supported between CPU and CUDA");
27+
28+
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
29+
// the enforced checks have been done in GetDeviceContext, so the
30+
// `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
31+
// slow, especially when the number of elements is small, for example,
32+
// the elements of learning rate are one and it's CPU side.
33+
// One solution is to use a CUDA kernel to complete the copy operation when
34+
// the transforming is from CPU to GPU and the number of elements is small.
35+
// But the embarrassment is that this solution makes training
36+
// slower.
37+
TensorCopySync(in, dst_place, out);
4438
}
4539

4640
} // namespace framework

0 commit comments

Comments
 (0)