 #include "paddle/fluid/framework/details/cow_ptr.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"

 #include "glog/logging.h"
@@ -31,46 +32,6 @@ namespace paddle {
 namespace framework {

 #if defined(PADDLE_WITH_CUDA)
-namespace details {
-struct CUDABuffer {
-  void *data_{nullptr};
-  size_t size_{0};
-  platform::CUDAPlace place_;
-
-  CUDABuffer() {}
-  CUDABuffer(platform::Place place, size_t size)
-      : size_(size), place_(boost::get<platform::CUDAPlace>(place)) {
-    data_ = memory::Alloc(place_, size);
-  }
-
-  ~CUDABuffer() { ClearMemory(); }
-
-  CUDABuffer(const CUDABuffer &o) = delete;
-  CUDABuffer &operator=(const CUDABuffer &o) = delete;
-
-  void Resize(platform::Place place, size_t size) {
-    ClearMemory();
-    place_ = boost::get<platform::CUDAPlace>(place);
-    data_ = memory::Alloc(place_, size);
-    PADDLE_ENFORCE_NOT_NULL(data_);
-    size_ = size;
-  }
-
-  void Swap(CUDABuffer &o) {
-    std::swap(data_, o.data_);
-    std::swap(place_, o.place_);
-    std::swap(size_, o.size_);
-  }
-
- private:
-  void ClearMemory() const {
-    if (data_ != nullptr) {
-      memory::Free(place_, data_);
-    }
-  }
-};
-}  // namespace details
-
 // Vector<T> implements the std::vector interface, and can get Data or
 // MutableData from any place. The data will be synced implicitly inside.
 template <typename T>
@@ -103,8 +64,6 @@ class Vector {
       o.ImmutableCPU();
       cpu_ = o.cpu_;
       flag_ = kDataInCPU;
-      details::CUDABuffer null;
-      gpu_.Swap(null);
       return *this;
     }

@@ -199,7 +158,7 @@ class Vector {
       PADDLE_ENFORCE(platform::is_gpu_place(place),
                      "CUDA Data must on CUDA place");
       ImmutableCUDA(place);
-      return reinterpret_cast<T *>(gpu_.data_);
+      return reinterpret_cast<T *>(gpu_->ptr());
     }

     // get cuda ptr. mutable
@@ -234,13 +193,11 @@ class Vector {

     std::mutex &Mutex() const { return mtx_; }

-    std::unique_ptr<platform::CUDAPlace> CUDAPlace() const {
-      if (gpu_.data_ == nullptr) {
-        return nullptr;
-      } else {
-        return std::unique_ptr<platform::CUDAPlace>(
-            new platform::CUDAPlace(gpu_.place_));
-      }
+    boost::optional<platform::CUDAPlace> CUDAPlace() const {
+      return gpu_ == nullptr
+                 ? boost::none
+                 : boost::optional<platform::CUDAPlace>(
+                       boost::get<platform::CUDAPlace>(gpu_->place()));
     }

    private:
@@ -254,13 +211,12 @@ class Vector {
     void CopyToCPU() const {
       // COPY GPU Data To CPU
       auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(
-              platform::Place(gpu_.place_)));
+          platform::DeviceContextPool::Instance().Get(gpu_->place()));
       auto stream = dev_ctx->stream();
-      void *src = gpu_.data_;
+      void *src = gpu_->ptr();
       void *dst = cpu_.data();
-      memory::Copy(platform::CPUPlace(), dst, gpu_.place_, src, gpu_.size_,
-                   stream);
+      memory::Copy(platform::CPUPlace(), dst, CUDAPlace().get(), src,
+                   gpu_->size(), stream);
       dev_ctx->Wait();
     }

@@ -277,8 +233,7 @@ class Vector {
         CopyCPUDataToCUDA(place);
         UnsetFlag(kDirty);
         SetFlag(kDataInCUDA);
-      } else if (IsInCUDA() &&
-                 !(boost::get<platform::CUDAPlace>(place) == gpu_.place_)) {
+      } else if (IsInCUDA() && !(place == gpu_->place())) {
         PADDLE_THROW("This situation should not happen");
         // Still dirty
       } else {
@@ -290,7 +245,7 @@ class Vector {
         // Even data is not dirty. However, data is not in CUDA. Copy data.
         CopyCPUDataToCUDA(place);
         SetFlag(kDataInCUDA);
-      } else if (!(boost::get<platform::CUDAPlace>(place) == gpu_.place_)) {
+      } else if (!(place == gpu_->place())) {
         PADDLE_THROW("This situation should not happen.");
       } else {
         // Not Dirty && DataInCUDA && Device is same
@@ -301,13 +256,13 @@ class Vector {

     void CopyCPUDataToCUDA(const platform::Place &place) const {
       void *src = cpu_.data();
-      gpu_.Resize(place, cpu_.size() * sizeof(T));
-      void *dst = gpu_.data_;
+      gpu_ = memory::Alloc(place, cpu_.size() * sizeof(T));
+      void *dst = gpu_->ptr();
       auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
           platform::DeviceContextPool::Instance().Get(place));
       auto stream = dev_ctx->stream();
-      memory::Copy(gpu_.place_, dst, platform::CPUPlace(), src, gpu_.size_,
-                   stream);
+      memory::Copy(CUDAPlace().get(), dst, platform::CPUPlace(), src,
+                   gpu_->size(), stream);
     }

     void ImmutableCPU() const {
@@ -329,7 +284,7 @@ class Vector {
     bool IsInCPU() const { return flag_ & kDataInCPU; }

     mutable std::vector<T> cpu_;
-    mutable details::CUDABuffer gpu_;
+    mutable memory::AllocationPtr gpu_;
     mutable int flag_;

     mutable std::mutex mtx_;
@@ -428,8 +383,8 @@ class Vector {
       auto &mtx = m_.Data().Mutex();
       std::lock_guard<std::mutex> guard(mtx);
       auto cuda_place = m_.Data().CUDAPlace();
-      if (cuda_place == nullptr ||
-          *cuda_place == boost::get<platform::CUDAPlace>(place)) {
+      if (cuda_place == boost::none ||
+          cuda_place == boost::get<platform::CUDAPlace>(place)) {
         return m_.Data().CUDAData(place);
       }
     }
@@ -444,8 +399,8 @@ class Vector {
       auto &mtx = m_.Data().Mutex();
       std::lock_guard<std::mutex> guard(mtx);
       auto cuda_place = m_.Data().CUDAPlace();
-      if (cuda_place == nullptr ||
-          *cuda_place == boost::get<platform::CUDAPlace>(place)) {
+      if (cuda_place == boost::none ||
+          cuda_place == boost::get<platform::CUDAPlace>(place)) {
        return m_.MutableData()->CUDAMutableData(place);
      }
    }
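
Note on the change above: the hunks drop the hand-rolled details::CUDABuffer (manual memory::Alloc/memory::Free plus separate data_/size_/place_ bookkeeping) in favour of a single memory::AllocationPtr obtained from memory::Alloc, and CUDAPlace() now returns boost::optional<platform::CUDAPlace> instead of a heap-allocated std::unique_ptr. Below is a minimal sketch, not part of the patch, of the resulting allocation/copy pattern when built with CUDA support. It relies only on the APIs that appear in the diff (memory::Alloc, AllocationPtr's ptr()/size()/place(), memory::Copy, CUDADeviceContext); the RoundTrip helper and the example namespace are hypothetical names used for illustration.

// Sketch only: shows how a device buffer is owned and used after this patch,
// assuming the paddle::memory API exactly as it appears in the diff.
#include <vector>

#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"

namespace example {  // hypothetical namespace, for illustration only

template <typename T>
void RoundTrip(const paddle::platform::Place &place, std::vector<T> *cpu) {
  namespace memory = paddle::memory;
  namespace platform = paddle::platform;

  // The GPU-side place, extracted the same way the patched code does it.
  auto gpu_place = boost::get<platform::CUDAPlace>(place);

  // memory::Alloc returns an AllocationPtr; the buffer is released when the
  // pointer is destroyed, so no explicit memory::Free / ClearMemory() call is
  // needed (this is what lets the patch delete CUDABuffer entirely).
  memory::AllocationPtr gpu = memory::Alloc(place, cpu->size() * sizeof(T));

  auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
      platform::DeviceContextPool::Instance().Get(place));
  auto stream = dev_ctx->stream();

  // CPU -> GPU, mirroring Vector<T>::CopyCPUDataToCUDA() above.
  memory::Copy(gpu_place, gpu->ptr(), platform::CPUPlace(), cpu->data(),
               gpu->size(), stream);

  // GPU -> CPU, mirroring Vector<T>::CopyToCPU(); the allocation carries its
  // own place() and size(), so no separate place_/size_ fields are tracked.
  memory::Copy(platform::CPUPlace(), cpu->data(), gpu_place, gpu->ptr(),
               gpu->size(), stream);
  dev_ctx->Wait();
}

}  // namespace example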