26
26
namespace paddle {
27
27
namespace framework {
28
28
29
+ #if defined(PADDLE_WITH_CUDA)
29
30
// Vector<T> implements the std::vector interface, and can get Data or
30
31
// MutableData from any place. The data will be synced implicitly inside.
31
32
template <typename T>
@@ -37,11 +38,11 @@ class Vector {
37
38
Vector () { InitEmpty (); }
38
39
39
40
// Fill vector with value. The vector size is `count`.
40
- explicit Vector (size_t count, const T& value = T()) {
41
+ explicit Vector (size_t count, const T & value = T()) {
41
42
InitEmpty ();
42
43
if (count != 0 ) {
43
44
resize (count);
44
- T* ptr = begin ();
45
+ T * ptr = begin ();
45
46
for (size_t i = 0 ; i < count; ++i) {
46
47
ptr[i] = value;
47
48
}
@@ -59,7 +60,7 @@ class Vector {
59
60
60
61
// implicit cast from std::vector.
61
62
template <typename U>
62
- Vector (const std::vector<U>& dat) { // NOLINT
63
+ Vector (const std::vector<U> & dat) { // NOLINT
63
64
if (dat.size () == 0 ) {
64
65
InitEmpty ();
65
66
} else {
@@ -68,10 +69,10 @@ class Vector {
68
69
}
69
70
70
71
// Copy ctor
71
- Vector (const Vector<T>& other) { this ->operator =(other); }
72
+ Vector (const Vector<T> & other) { this ->operator =(other); }
72
73
73
74
// Copy operator
74
- Vector<T>& operator =(const Vector<T>& other) {
75
+ Vector<T> & operator =(const Vector<T> & other) {
75
76
if (other.size () != 0 ) {
76
77
this ->InitByIter (other.size (), other.begin (), other.end ());
77
78
} else {
@@ -81,7 +82,7 @@ class Vector {
81
82
}
82
83
83
84
// Move ctor
84
- Vector (Vector<T>&& other) {
85
+ Vector (Vector<T> && other) {
85
86
this ->size_ = other.size_ ;
86
87
this ->flag_ = other.flag_ ;
87
88
if (other.cuda_vec_ .memory_size ()) {
@@ -93,57 +94,57 @@ class Vector {
93
94
}
94
95
95
96
// CPU data access method. Mutable.
96
- T& operator [](size_t i) {
97
+ T & operator [](size_t i) {
97
98
MutableCPU ();
98
- return const_cast <T*>(cpu_vec_.data <T>())[i];
99
+ return const_cast <T *>(cpu_vec_.data <T>())[i];
99
100
}
100
101
101
102
// CPU data access method. Immutable.
102
- const T& operator [](size_t i) const {
103
+ const T & operator [](size_t i) const {
103
104
ImmutableCPU ();
104
105
return cpu_vec_.data <T>()[i];
105
106
}
106
107
107
108
// std::vector iterator methods. Based on CPU data access method
108
109
size_t size () const { return size_; }
109
110
110
- T* begin () { return capacity () == 0 ? &EmptyDummy () : &this ->operator [](0 ); }
111
+ T * begin () { return capacity () == 0 ? &EmptyDummy () : &this ->operator [](0 ); }
111
112
112
- T* end () {
113
+ T * end () {
113
114
return capacity () == 0 ? &EmptyDummy () : &this ->operator [](size ());
114
115
}
115
116
116
- T& front () { return *begin (); }
117
+ T & front () { return *begin (); }
117
118
118
- T& back () {
119
+ T & back () {
119
120
auto it = end ();
120
121
--it;
121
122
return *it;
122
123
}
123
124
124
- const T* begin () const {
125
+ const T * begin () const {
125
126
return capacity () == 0 ? &EmptyDummy () : &this ->operator [](0 );
126
127
}
127
128
128
- const T* end () const {
129
+ const T * end () const {
129
130
return capacity () == 0 ? &EmptyDummy () : &this ->operator [](size ());
130
131
}
131
132
132
- const T* cbegin () const { return begin (); }
133
+ const T * cbegin () const { return begin (); }
133
134
134
- const T* cend () const { return end (); }
135
+ const T * cend () const { return end (); }
135
136
136
- const T& back () const {
137
+ const T & back () const {
137
138
auto it = end ();
138
139
--it;
139
140
return *it;
140
141
}
141
142
142
- T* data () { return begin (); }
143
+ T * data () { return begin (); }
143
144
144
- const T* data () const { return begin (); }
145
+ const T * data () const { return begin (); }
145
146
146
- const T& front () const { return *begin (); }
147
+ const T & front () const { return *begin (); }
147
148
// end of std::vector iterator methods
148
149
149
150
// assign this from iterator.
@@ -169,7 +170,7 @@ class Vector {
169
170
void Extend (It begin, It end) {
170
171
size_t pre_size = size_;
171
172
resize (pre_size + (end - begin));
172
- T* ptr = this ->begin () + pre_size;
173
+ T * ptr = this ->begin () + pre_size;
173
174
for (; begin < end; ++begin, ++ptr) {
174
175
*ptr = *begin;
175
176
}
@@ -183,9 +184,9 @@ class Vector {
183
184
MutableCPU ();
184
185
Tensor cpu_tensor;
185
186
platform::Place cpu = platform::CPUPlace ();
186
- T* ptr = cpu_tensor.mutable_data <T>(
187
+ T * ptr = cpu_tensor.mutable_data <T>(
187
188
framework::make_ddim ({static_cast <int64_t >(size)}), cpu);
188
- const T* old_ptr =
189
+ const T * old_ptr =
189
190
cpu_vec_.memory_size () == 0 ? nullptr : cpu_vec_.data <T>();
190
191
if (old_ptr != nullptr ) {
191
192
std::copy (old_ptr, old_ptr + size_, ptr);
@@ -196,18 +197,18 @@ class Vector {
196
197
}
197
198
198
199
// get cuda ptr. immutable
199
- const T* CUDAData (platform::Place place) const {
200
+ const T * CUDAData (platform::Place place) const {
200
201
PADDLE_ENFORCE (platform::is_gpu_place (place),
201
202
" CUDA Data must on CUDA place" );
202
203
ImmutableCUDA (place);
203
204
return cuda_vec_.data <T>();
204
205
}
205
206
206
207
// get cuda ptr. mutable
207
- T* CUDAMutableData (platform::Place place) {
208
- const T* ptr = CUDAData (place);
208
+ T * CUDAMutableData (platform::Place place) {
209
+ const T * ptr = CUDAData (place);
209
210
flag_ = kDirty | kDataInCUDA ;
210
- return const_cast <T*>(ptr);
211
+ return const_cast <T *>(ptr);
211
212
}
212
213
213
214
// clear
@@ -228,7 +229,7 @@ class Vector {
228
229
}
229
230
230
231
// the unify method to access CPU or CUDA data. immutable.
231
- const T* Data (platform::Place place) const {
232
+ const T * Data (platform::Place place) const {
232
233
if (platform::is_gpu_place (place)) {
233
234
return CUDAData (place);
234
235
} else {
@@ -237,7 +238,7 @@ class Vector {
237
238
}
238
239
239
240
// the unify method to access CPU or CUDA data. mutable.
240
- T* MutableData (platform::Place place) {
241
+ T * MutableData (platform::Place place) {
241
242
if (platform::is_gpu_place (place)) {
242
243
return CUDAMutableData (place);
243
244
} else {
@@ -253,7 +254,7 @@ class Vector {
253
254
return result;
254
255
}
255
256
256
- bool operator ==(const Vector<T>& other) const {
257
+ bool operator ==(const Vector<T> & other) const {
257
258
if (size () != other.size ()) return false ;
258
259
auto it1 = cbegin ();
259
260
auto it2 = other.cbegin ();
@@ -274,7 +275,7 @@ class Vector {
274
275
template <typename Iter>
275
276
void InitByIter (size_t size, Iter begin, Iter end) {
276
277
platform::Place cpu = platform::CPUPlace ();
277
- T* ptr = this ->cpu_vec_ .template mutable_data <T>(
278
+ T * ptr = this ->cpu_vec_ .template mutable_data <T>(
278
279
framework::make_ddim ({static_cast <int64_t >(size)}), cpu);
279
280
for (size_t i = 0 ; i < size; ++i) {
280
281
*ptr++ = *begin++;
@@ -368,7 +369,7 @@ class Vector {
368
369
}
369
370
}
370
371
371
- static T& EmptyDummy () {
372
+ static T & EmptyDummy () {
372
373
static T dummy = T ();
373
374
return dummy;
374
375
}
@@ -379,5 +380,53 @@ class Vector {
379
380
size_t size_;
380
381
};
381
382
382
- } // namespace framework
383
+ #else // PADDLE_WITH_CUDA
384
+
385
+ template <typename T>
386
+ class CPUVector : public std ::vector<T, std::allocator<T>> {
387
+ public:
388
+ CPUVector () : std::vector<T>() {}
389
+ CPUVector (size_t count, const T &value = T())
390
+ : std::vector<T>(count, value) {}
391
+ CPUVector (std::initializer_list<T> init) : std::vector<T>(init) {}
392
+ CPUVector (const std::vector<T> &other) : std::vector<T>(other) {}
393
+ explicit CPUVector (const CPUVector<T> &other) : std::vector<T>(other) {}
394
+ CPUVector (CPUVector<T> &&other) : std::vector<T>(std::move(other)) {}
395
+ CPUVector (std::vector<T> &&other) : std::vector<T>(std::move(other)) {}
396
+ CPUVector &operator =(const CPUVector &other) {
397
+ this ->assign (other.begin (), other.end ());
398
+ return *this ;
399
+ }
400
+ CPUVector &operator =(const std::vector<T> &other) {
401
+ this ->assign (other.begin (), other.end ());
402
+ return *this ;
403
+ }
404
+
405
+ friend std::ostream &operator <<(std::ostream &os, const CPUVector<T> &other) {
406
+ std::stringstream ss;
407
+ for (auto v : other) {
408
+ os << v << " " ;
409
+ }
410
+ return os;
411
+ }
412
+
413
+ void resize (size_t size) { this ->resize (size); }
414
+
415
+ T &operator [](size_t id) { return this ->at (id); }
416
+
417
+ const T &operator [](size_t id) const { return this ->at (id); }
418
+
419
+ template <typename D>
420
+ void Extend (const D &begin, const D &end) {
421
+ this ->reserve (this ->size () + size_t (end - begin));
422
+ this ->insert (this ->end (), begin, end);
423
+ }
424
+ };
425
+
426
+ template <typename T>
427
+ using Vector = CPUVector<T>;
428
+
429
+ #endif // PADDLE_WITH_CUDA
430
+
431
+ }; // namespace framework
383
432
} // namespace paddle
0 commit comments