Skip to content

Commit 3fcd16e

Browse files
committed
init double buffer
1 parent 86263b2 commit 3fcd16e

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

paddle/fluid/framework/reader.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,46 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
112112
out->push_back(out_tensor);
113113
}
114114
}
115+
116+
void DoubleBufferReader::ReadNext(std::vector<LoDTensor>* out) {
117+
std::unique_lock<std::mutex> lck(mtx_);
118+
while (write_pos_ == read_pos_) {
119+
buffer_not_empty_.wait(lck);
120+
}
121+
122+
out->clear();
123+
out->resize(buffer_[read_pos_].size());
124+
// TODO(fengjiayi): This copy shall be reduced.
125+
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
126+
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &out[i]);
127+
out[i].set_lod(buffer_[read_pos_][i].lod());
128+
}
129+
130+
++read_pos_;
131+
if (read_pos_ >= kDoubleBufferSize) {
132+
read_pos_ = 0;
133+
}
134+
buffer_not_full_.notify_all();
135+
}
136+
137+
bool DoubleBufferReader::HasNext() {
138+
return reader_->HasNext() || !buffer_.empty();
139+
}
140+
141+
void DoubleBufferReader::ProducerThreadFunc() {
142+
while (reader_->HasNext()) {
143+
std::unique_lock<std::mutex> lck(mtx);
144+
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
145+
buffer_not_full_.wait(lck);
146+
}
147+
reader_->ReadNext(&buffer_[write_pos_]);
148+
++write_pos_;
149+
if (write_pos_ >= kDoubleBufferSize) {
150+
write_pos_ = 0;
151+
}
152+
buffer_not_empty_.notify_all();
153+
}
154+
}
155+
115156
} // namespace framework
116157
} // namespace paddle

paddle/fluid/framework/reader.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717
#include "paddle/fluid/framework/ddim.h"
1818
#include "paddle/fluid/framework/lod_tensor_array.h"
19+
#include "paddle/fluid/framework/threadpool.h"
1920

2021
namespace paddle {
2122
namespace framework {
2223

24+
static constexpr size_t kDoubleBufferSize = 3;
25+
2326
class ReaderBase {
2427
public:
2528
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
@@ -135,6 +138,28 @@ class BatchReader : public DecoratedReader {
135138
std::vector<std::vector<LoDTensor>> buffer_;
136139
};
137140

141+
class DoubleBufferReader : public DecoratedReader {
142+
public:
143+
DoubleBufferReader(ReaderBase* reader)
144+
: DecoratedReader(reader), buffer_(kDoubleBufferSize) {
145+
framework::Async(std::bind(&DoubleBufferReader::ProducerThreadFunc, this));
146+
}
147+
148+
void ReadNext(std::vector<LoDTensor>* out) override;
149+
bool HasNext() const override;
150+
151+
private:
152+
void ProducerThreadFunc();
153+
154+
std::vector<std::vector<LoDTensor>> buffer_;
155+
size_t write_pos_;
156+
size_t read_pos_;
157+
158+
std::mutex mtx_;
159+
std::condition_variable buffer_not_full_;
160+
std::condition_variable buffer_not_empty_;
161+
};
162+
138163
// The ReaderHolder is used as readers' unified wrapper,
139164
// making it easier to access different type readers in Variables.
140165
class ReaderHolder {

0 commit comments

Comments
 (0)