
Commit 899827f

Merge pull request #9535 from reyoung/feature/fix_double_buffer
Add local cache of double buffer reader
2 parents 3fd9266 + b94f24d commit 899827f

File tree

1 file changed (+72, -53)

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 72 additions & 53 deletions
@@ -20,12 +20,29 @@ namespace paddle {
 namespace operators {
 namespace reader {
 
-static constexpr size_t kDoubleBufferSize = 2;
+// 'Double buffer' means we shall maintain two batches of input data at the
+// same time. So kCacheSize should be at least 2.
+static constexpr size_t kCacheSize = 2;
+// There will be two batches out of the channel during training:
+// 1. the one waiting to be sent to the channel
+// 2. the one just received from the channel, which is also being used by
+//    subsequent operators.
+// So the channel size should be kCacheSize - 2.
+static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
 
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
   struct Item {
     Item() : ctx_(nullptr) {}
+    Item(Item&& b) {
+      payloads_ = std::move(b.payloads_);
+      ctx_ = std::move(b.ctx_);
+    }
+    Item& operator=(Item&& b) {
+      payloads_ = std::move(b.payloads_);
+      ctx_ = std::move(b.ctx_);
+      return *this;
+    }
 
     std::vector<framework::LoDTensor> payloads_;
     platform::DeviceContext* ctx_;
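
The hunk above swaps kDoubleBufferSize for the kCacheSize/kChannelSize pair and gives Item a move constructor and move assignment, since an Item is handed off between the prefetch thread and the channel. Moving the payload vector transfers its heap storage in O(1) instead of copying every tensor. A minimal sketch of that handoff, using std::vector<float> and a raw void* as illustrative stand-ins for std::vector<framework::LoDTensor> and platform::DeviceContext* (these stand-ins are not Paddle types):

#include <iostream>
#include <utility>
#include <vector>

struct Item {
  Item() : ctx_(nullptr) {}
  // Moving steals the payload buffer in O(1); the source is left in a
  // valid (in practice empty) state rather than deep-copied.
  Item(Item&& b) : payloads_(std::move(b.payloads_)), ctx_(b.ctx_) {}
  Item& operator=(Item&& b) {
    payloads_ = std::move(b.payloads_);
    ctx_ = b.ctx_;  // non-owning raw pointer, so a plain copy is fine
    return *this;
  }

  std::vector<float> payloads_;  // stand-in for std::vector<LoDTensor>
  void* ctx_;                    // stand-in for platform::DeviceContext*
};

int main() {
  Item produced;
  produced.payloads_ = {1.f, 2.f, 3.f};
  Item received = std::move(produced);  // the handoff Send()/Receive() rely on
  std::cout << "received holds " << received.payloads_.size()
            << " elements; source holds " << produced.payloads_.size() << "\n";
  return 0;
}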
@@ -34,42 +51,44 @@ class DoubleBufferReader : public framework::DecoratedReader {
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
-    for (size_t i = 0; i < kDoubleBufferSize; ++i) {
-      if (platform::is_gpu_place(place_)) {
 #ifdef PADDLE_WITH_CUDA
+    for (size_t i = 0; i < kCacheSize; ++i) {
+      if (platform::is_gpu_place(place_)) {
         ctxs_.emplace_back(new platform::CUDADeviceContext(
             boost::get<platform::CUDAPlace>(place_)));
-#endif
       }
     }
-
-    start_thread();
-  }
-
-  void start_thread() {
-    buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
-    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
+#endif
+    StartPrefetcher();
   }
 
+  bool HasNext() const override;
   void ReadNext(std::vector<framework::LoDTensor>* out) override;
   void ReInit() override;
 
-  ~DoubleBufferReader() {
-    buffer_->Close();
-    prefetcher_.join();
-    delete buffer_;
+  ~DoubleBufferReader() { EndPrefetcher(); }
+
+ private:
+  void StartPrefetcher() {
+    channel_ = framework::MakeChannel<Item>(kChannelSize);
+    prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
   }
 
-  bool HasNext() const override;
+  void EndPrefetcher() {
+    channel_->Close();
+    if (prefetcher_.joinable()) {
+      prefetcher_.join();
+    }
+    delete channel_;
+    channel_ = nullptr;
+  }
 
- private:
   void PrefetchThreadFunc();
 
   std::thread prefetcher_;
-  framework::Channel<Item>* buffer_;
+  framework::Channel<Item>* channel_;
   platform::Place place_;
   std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
-  mutable Item local_buffer_;
 };
 
 class CreateDoubleBufferReaderOp : public framework::OperatorBase {
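
This hunk folds thread and channel setup/teardown into StartPrefetcher()/EndPrefetcher(). The ordering matters: Close() is called before join() so a Send() blocked in the prefetch thread can wake up and exit, and the joinable() guard makes EndPrefetcher() safe even when the thread has already finished. Below is a self-contained sketch of the same lifecycle under stated assumptions; ToyChannel and ToyReader are invented stand-ins, and Paddle's framework::Channel additionally supports capacity 0 (kChannelSize) as a rendezvous, which this toy queue does not reproduce:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

// Toy closable bounded channel (requires a positive capacity).
template <typename T>
class ToyChannel {
 public:
  explicit ToyChannel(size_t cap) : cap_(cap) {}
  bool Send(T v) {
    std::unique_lock<std::mutex> l(mu_);
    send_cv_.wait(l, [this] { return closed_ || q_.size() < cap_; });
    if (closed_) return false;  // analogous to Send() throwing after Close()
    q_.push(std::move(v));
    recv_cv_.notify_one();
    return true;
  }
  bool Receive(T* out) {
    std::unique_lock<std::mutex> l(mu_);
    recv_cv_.wait(l, [this] { return closed_ || !q_.empty(); });
    if (q_.empty()) return false;  // closed and fully drained
    *out = std::move(q_.front());
    q_.pop();
    send_cv_.notify_one();
    return true;
  }
  void Close() {
    std::lock_guard<std::mutex> l(mu_);
    closed_ = true;
    send_cv_.notify_all();
    recv_cv_.notify_all();
  }

 private:
  const size_t cap_;
  bool closed_ = false;
  std::queue<T> q_;
  std::mutex mu_;
  std::condition_variable send_cv_, recv_cv_;
};

class ToyReader {
 public:
  // Mirrors StartPrefetcher(): allocate the channel, then launch the thread.
  void StartPrefetcher() {
    channel_ = new ToyChannel<int>(2);
    prefetcher_ = std::thread([this] {
      for (int i = 0; i < 4; ++i) {
        if (!channel_->Send(i)) return;  // channel closed mid-send
      }
      channel_->Close();  // source exhausted
    });
  }
  // Mirrors EndPrefetcher(): Close() first so a blocked Send() wakes up;
  // the joinable() guard tolerates a thread that already exited.
  void EndPrefetcher() {
    channel_->Close();
    if (prefetcher_.joinable()) prefetcher_.join();
    delete channel_;
    channel_ = nullptr;
  }
  ToyChannel<int>* channel_ = nullptr;

 private:
  std::thread prefetcher_;
};

int main() {
  ToyReader r;
  r.StartPrefetcher();
  int v;
  while (r.channel_->Receive(&v)) std::cout << v << ' ';  // 0 1 2 3
  std::cout << '\n';
  r.EndPrefetcher();
  r.StartPrefetcher();  // ReInit() is exactly EndPrefetcher + StartPrefetcher
  r.EndPrefetcher();    // safe shutdown even though nothing was received
  return 0;
}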
@@ -123,70 +142,70 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
   }
 };
 
+bool DoubleBufferReader::HasNext() const {
+  while (!channel_->IsClosed() && !channel_->CanReceive()) {
+  }
+  return channel_->CanReceive();
+}
+
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   if (!HasNext()) {
     PADDLE_THROW("There is no next data!");
   }
 
-  if (local_buffer_.payloads_.empty()) {
-    buffer_->Receive(&local_buffer_);
-  }
-  *out = local_buffer_.payloads_;
-  local_buffer_.payloads_.clear();
-  if (local_buffer_.ctx_) {
-    local_buffer_.ctx_->Wait();
+  Item batch;
+  channel_->Receive(&batch);
+  *out = batch.payloads_;
+  if (batch.ctx_) {
+    batch.ctx_->Wait();
   }
 }
 
 void DoubleBufferReader::ReInit() {
   reader_->ReInit();
-  buffer_->Close();
-  prefetcher_.join();
-  delete buffer_;
-  start_thread();
+  EndPrefetcher();
+  StartPrefetcher();
 }
 
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
-  size_t gpu_ctx_offset = 0;
+  std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache(kCacheSize);
+  std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache(kCacheSize);
+  size_t cached_tensor_id = 0;
+
   while (reader_->HasNext()) {
     Item batch;
-    reader_->ReadNext(&batch.payloads_);
+    auto& cpu_batch = cpu_tensor_cache[cached_tensor_id];
+    reader_->ReadNext(&cpu_batch);
     if (platform::is_gpu_place(place_)) {
-      std::vector<framework::LoDTensor> gpu_batch;
-      auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
-      gpu_ctx_offset %= this->ctxs_.size();
-      gpu_batch.resize(batch.payloads_.size());
-      for (size_t i = 0; i < batch.payloads_.size(); ++i) {
-        framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
-                              &gpu_batch[i]);
-        gpu_batch[i].set_lod(batch.payloads_[i].lod());
+      auto& gpu_batch = gpu_tensor_cache[cached_tensor_id];
+      auto* gpu_ctx = ctxs_[cached_tensor_id].get();
+      gpu_batch.resize(cpu_batch.size());
+      for (size_t i = 0; i < cpu_batch.size(); ++i) {
+        framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]);
+        gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
-      batch.ctx_ = gpu_ctx.get();
-      std::swap(gpu_batch, batch.payloads_);
+      batch.payloads_ = gpu_batch;
+      batch.ctx_ = gpu_ctx;
+    } else {
+      // CPUPlace
+      batch.payloads_ = cpu_batch;
     }
+    ++cached_tensor_id;
+    cached_tensor_id %= kCacheSize;
 
     try {
-      buffer_->Send(&batch);
+      channel_->Send(&batch);
     } catch (paddle::platform::EnforceNotMet e) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
                  "prefetch thread will terminate.";
       break;
     }
   }
-  buffer_->Close();
+  channel_->Close();
   VLOG(5) << "Prefetch thread terminates.";
 }
 
-bool DoubleBufferReader::HasNext() const {
-  if (local_buffer_.payloads_.empty()) {
-    bool ok = buffer_->Receive(&local_buffer_);
-    return ok;
-  } else {
-    return true;
-  }
-}
-
 } // namespace reader
 } // namespace operators
 } // namespace paddle
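
The core of the commit is the local cache in PrefetchThreadFunc(): instead of allocating a fresh batch each iteration, the thread rotates through kCacheSize pre-allocated CPU (and GPU) buffer slots via cached_tensor_id. Because the slot being refilled is never the slot most recently sent, a batch handed to downstream operators stays intact while the next one is prefetched, and each GPU batch is paired with the device context ctxs_[cached_tensor_id] that ReadNext() later waits on. A minimal sketch of the rotation, with std::vector<int> standing in for a batch of LoDTensors (names here are illustrative, not Paddle APIs):

#include <array>
#include <iostream>
#include <vector>

// Same rotation as PrefetchThreadFunc(): with kCacheSize slots, the slot
// being refilled is never the one most recently handed out, so the previous
// batch stays valid while downstream work (or an in-flight GPU copy) still
// references it.
static constexpr size_t kCacheSize = 2;

int main() {
  // Stand-in for cpu_tensor_cache / gpu_tensor_cache.
  std::array<std::vector<int>, kCacheSize> tensor_cache;
  size_t cached_tensor_id = 0;

  for (int step = 0; step < 5; ++step) {
    auto& slot = tensor_cache[cached_tensor_id];
    slot.assign(3, step);  // "ReadNext" refills the slot, reusing its storage
    std::cout << "step " << step << " -> slot " << cached_tensor_id << '\n';
    ++cached_tensor_id;
    cached_tensor_id %= kCacheSize;  // rotate: 0, 1, 0, 1, ...
  }
  return 0;
}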
