@@ -58,12 +58,13 @@ void GetTensorPayload(framework::Variable* var,
58
58
if (platform::is_gpu_place (ctx.GetPlace ())) {
59
59
#ifdef PADDLE_WITH_CUDA
60
60
PADDLE_ENFORCE (platform::is_gpu_place (tensor.place ()));
61
- platform::CPUPlace cpu ;
61
+ platform::CUDAPinnedPlace cuda_pinned ;
62
62
auto & gpu_dev_ctx = static_cast <const platform::CUDADeviceContext&>(ctx);
63
63
auto copy_size = tensor.numel () * framework::SizeOfType (tensor.type ());
64
- *payload = memory::Alloc (cpu , copy_size);
64
+ *payload = memory::Alloc (cuda_pinned , copy_size);
65
65
66
- memory::Copy (cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place ()),
66
+ memory::Copy (cuda_pinned, *payload,
67
+ boost::get<platform::CUDAPlace>(tensor.place ()),
67
68
reinterpret_cast <const void *>(tensor.data <void >()), copy_size,
68
69
gpu_dev_ctx.stream ());
69
70
ctx.Wait ();
@@ -90,11 +91,11 @@ void GetSelectedRowsPayload(framework::Variable* var,
90
91
auto * tensor = slr->mutable_value ();
91
92
if (platform::is_gpu_place (ctx.GetPlace ())) {
92
93
#ifdef PADDLE_WITH_CUDA
93
- platform::CPUPlace cpu ;
94
+ platform::CUDAPinnedPlace cuda_pinned ;
94
95
auto & gpu_dev_ctx = static_cast <const platform::CUDADeviceContext&>(ctx);
95
96
auto copy_size = tensor->numel () * framework::SizeOfType (tensor->type ());
96
- *payload = memory::Alloc (cpu , copy_size);
97
- memory::Copy (cpu , *payload,
97
+ *payload = memory::Alloc (cuda_pinned , copy_size);
98
+ memory::Copy (cuda_pinned , *payload,
98
99
boost::get<platform::CUDAPlace>(tensor->place ()),
99
100
reinterpret_cast <const void *>(tensor->data <void >()), copy_size,
100
101
gpu_dev_ctx.stream ());
@@ -145,8 +146,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
145
146
// GPU data is copied to CPU buffer when sending,
146
147
// free the buffer when possible.
147
148
destroy_callback = [](void * backing) {
148
- platform::CPUPlace cpu ;
149
- memory::Free (cpu , backing);
149
+ platform::CUDAPinnedPlace cuda_pinned ;
150
+ memory::Free (cuda_pinned , backing);
150
151
};
151
152
}
152
153
0 commit comments