@@ -20,12 +20,29 @@ namespace paddle {
20
20
namespace operators {
21
21
namespace reader {
22
22
23
// 'Double buffer' means we shall maintain two batches of input data at the
// same time. So kCacheSize should be at least 2.
static constexpr size_t kCacheSize = 2;
// There will be two batches out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just received from the channel, which is also being used by
//    subsequent operators.
// So the channel size should be kCacheSize - 2.
static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
24
32
25
33
class DoubleBufferReader : public framework ::DecoratedReader {
26
34
public:
27
35
struct Item {
28
36
Item () : ctx_(nullptr ) {}
37
+ Item (Item&& b) {
38
+ payloads_ = std::move (b.payloads_ );
39
+ ctx_ = std::move (b.ctx_ );
40
+ }
41
+ Item& operator =(Item&& b) {
42
+ payloads_ = std::move (b.payloads_ );
43
+ ctx_ = std::move (b.ctx_ );
44
+ return *this ;
45
+ }
29
46
30
47
std::vector<framework::LoDTensor> payloads_;
31
48
platform::DeviceContext* ctx_;
@@ -34,42 +51,44 @@ class DoubleBufferReader : public framework::DecoratedReader {
34
51
explicit DoubleBufferReader (
35
52
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
36
53
: DecoratedReader(reader), place_(target_place) {
37
- for (size_t i = 0 ; i < kDoubleBufferSize ; ++i) {
38
- if (platform::is_gpu_place (place_)) {
39
54
#ifdef PADDLE_WITH_CUDA
55
+ for (size_t i = 0 ; i < kCacheSize ; ++i) {
56
+ if (platform::is_gpu_place (place_)) {
40
57
ctxs_.emplace_back (new platform::CUDADeviceContext (
41
58
boost::get<platform::CUDAPlace>(place_)));
42
- #endif
43
59
}
44
60
}
45
-
46
- start_thread ();
47
- }
48
-
49
- void start_thread () {
50
- buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize );
51
- prefetcher_ = std::thread ([this ] { PrefetchThreadFunc (); });
61
+ #endif
62
+ StartPrefetcher ();
52
63
}
53
64
65
+ bool HasNext () const override ;
54
66
void ReadNext (std::vector<framework::LoDTensor>* out) override ;
55
67
void ReInit () override ;
56
68
57
- ~DoubleBufferReader () {
58
- buffer_->Close ();
59
- prefetcher_.join ();
60
- delete buffer_;
69
+ ~DoubleBufferReader () { EndPrefetcher (); }
70
+
71
+ private:
72
+ void StartPrefetcher () {
73
+ channel_ = framework::MakeChannel<Item>(kChannelSize );
74
+ prefetcher_ = std::thread ([this ] { PrefetchThreadFunc (); });
61
75
}
62
76
63
- bool HasNext () const override ;
77
+ void EndPrefetcher () {
78
+ channel_->Close ();
79
+ if (prefetcher_.joinable ()) {
80
+ prefetcher_.join ();
81
+ }
82
+ delete channel_;
83
+ channel_ = nullptr ;
84
+ }
64
85
65
- private:
66
86
void PrefetchThreadFunc ();
67
87
68
88
std::thread prefetcher_;
69
- framework::Channel<Item>* buffer_ ;
89
+ framework::Channel<Item>* channel_ ;
70
90
platform::Place place_;
71
91
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
72
- mutable Item local_buffer_;
73
92
};
74
93
75
94
class CreateDoubleBufferReaderOp : public framework ::OperatorBase {
@@ -123,70 +142,70 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
123
142
}
124
143
};
125
144
145
// Blocks until the channel either has an item ready or has been closed, then
// reports whether an item can be received.
// NOTE(review): this is a pure spin loop with no yield or sleep — it burns a
// full core whenever the prefetcher is slower than the consumer. Consider a
// blocking receive or std::this_thread::yield() in the loop body.
bool DoubleBufferReader::HasNext() const {
  while (!channel_->IsClosed() && !channel_->CanReceive()) {
  }
  return channel_->CanReceive();
}
150
+
126
151
// Waits (via HasNext) until a prefetched batch is available, hands its
// payload to the caller, and waits on the producing device context so any
// in-flight host-to-device copy has finished before the data is used.
// Throws when the underlying reader is exhausted and the channel drained.
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  if (!HasNext()) {
    PADDLE_THROW("There is no next data!");
  }

  Item batch;
  channel_->Receive(&batch);
  // `batch` is a local about to be destroyed; move the payload out instead
  // of copying the whole tensor vector.
  *out = std::move(batch.payloads_);
  if (batch.ctx_) {
    // Ensure the async copy issued by the prefetch thread has completed.
    batch.ctx_->Wait();
  }
}
140
163
141
164
void DoubleBufferReader::ReInit () {
142
165
reader_->ReInit ();
143
- buffer_->Close ();
144
- prefetcher_.join ();
145
- delete buffer_;
146
- start_thread ();
166
+ EndPrefetcher ();
167
+ StartPrefetcher ();
147
168
}
148
169
149
170
void DoubleBufferReader::PrefetchThreadFunc () {
150
171
VLOG (5 ) << " A new prefetch thread starts." ;
151
- size_t gpu_ctx_offset = 0 ;
172
+ std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache (kCacheSize );
173
+ std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache (kCacheSize );
174
+ size_t cached_tensor_id = 0 ;
175
+
152
176
while (reader_->HasNext ()) {
153
177
Item batch;
154
- reader_->ReadNext (&batch.payloads_ );
178
+ auto & cpu_batch = cpu_tensor_cache[cached_tensor_id];
179
+ reader_->ReadNext (&cpu_batch);
155
180
if (platform::is_gpu_place (place_)) {
156
- std::vector<framework::LoDTensor> gpu_batch;
157
- auto & gpu_ctx = this ->ctxs_ [gpu_ctx_offset++];
158
- gpu_ctx_offset %= this ->ctxs_ .size ();
159
- gpu_batch.resize (batch.payloads_ .size ());
160
- for (size_t i = 0 ; i < batch.payloads_ .size (); ++i) {
161
- framework::TensorCopy (batch.payloads_ [i], place_, *gpu_ctx,
162
- &gpu_batch[i]);
163
- gpu_batch[i].set_lod (batch.payloads_ [i].lod ());
181
+ auto & gpu_batch = gpu_tensor_cache[cached_tensor_id];
182
+ auto * gpu_ctx = ctxs_[cached_tensor_id].get ();
183
+ gpu_batch.resize (cpu_batch.size ());
184
+ for (size_t i = 0 ; i < cpu_batch.size (); ++i) {
185
+ framework::TensorCopy (cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
186
+ gpu_batch[i].set_lod (cpu_batch[i].lod ());
164
187
}
165
- batch.ctx_ = gpu_ctx.get ();
166
- std::swap (gpu_batch, batch.payloads_ );
188
+ batch.payloads_ = gpu_batch;
189
+ batch.ctx_ = gpu_ctx;
190
+ } else {
191
+ // CPUPlace
192
+ batch.payloads_ = cpu_batch;
167
193
}
194
+ ++cached_tensor_id;
195
+ cached_tensor_id %= kCacheSize ;
168
196
169
197
try {
170
- buffer_ ->Send (&batch);
198
+ channel_ ->Send (&batch);
171
199
} catch (paddle::platform::EnforceNotMet e) {
172
200
VLOG (5 ) << " WARNING: The double buffer channel has been closed. The "
173
201
" prefetch thread will terminate." ;
174
202
break ;
175
203
}
176
204
}
177
- buffer_ ->Close ();
205
+ channel_ ->Close ();
178
206
VLOG (5 ) << " Prefetch thread terminates." ;
179
207
}
180
208
181
- bool DoubleBufferReader::HasNext () const {
182
- if (local_buffer_.payloads_ .empty ()) {
183
- bool ok = buffer_->Receive (&local_buffer_);
184
- return ok;
185
- } else {
186
- return true ;
187
- }
188
- }
189
-
190
209
} // namespace reader
191
210
} // namespace operators
192
211
} // namespace paddle
0 commit comments