Skip to content

Commit 35e1e0d

Browse files
committed
uses channel to replace the traditional buffer
1 parent b3a11fd commit 35e1e0d

File tree

1 file changed

+31
-46
lines changed

1 file changed

+31
-46
lines changed

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,35 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <condition_variable>
16-
#include <mutex>
1715
#include <thread>
16+
#include "paddle/fluid/framework/channel.h"
1817
#include "paddle/fluid/operators/reader/reader_op_registry.h"
1918

2019
namespace paddle {
2120
namespace operators {
2221
namespace reader {
2322

24-
static constexpr size_t kDoubleBufferSize = 3;
23+
static constexpr size_t kDoubleBufferSize = 2;
2524

2625
class DoubleBufferReader : public framework::DecoratedReader {
2726
public:
2827
explicit DoubleBufferReader(ReaderBase* reader)
2928
: DecoratedReader(reader),
30-
buffer_(kDoubleBufferSize),
31-
write_pos_(0),
32-
read_pos_(0) {
33-
std::thread prefetch(
34-
std::bind(&DoubleBufferReader::PrefetchThreadFunc, this));
29+
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
30+
kDoubleBufferSize)) {
31+
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
3532
prefetch.detach();
3633
}
3734

3835
void ReadNext(std::vector<framework::LoDTensor>* out) override;
39-
bool HasNext() const override;
36+
void ReInit() override;
37+
38+
~DoubleBufferReader() { buffer_->Close(); }
4039

4140
private:
4241
void PrefetchThreadFunc();
4342

44-
std::vector<std::vector<framework::LoDTensor>> buffer_;
45-
size_t write_pos_;
46-
size_t read_pos_;
47-
48-
std::mutex mtx_;
49-
std::condition_variable buffer_not_full_;
50-
std::condition_variable buffer_not_empty_;
43+
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
5144
};
5245

5346
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
@@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
8073
};
8174

8275
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
83-
std::unique_lock<std::mutex> lck(mtx_);
84-
while (write_pos_ == read_pos_) {
85-
buffer_not_empty_.wait(lck);
86-
}
87-
8876
out->clear();
89-
out->reserve(buffer_[read_pos_].size());
90-
// TODO(fengjiayi): This copy shall be reduced.
91-
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
92-
framework::LoDTensor dst;
93-
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst);
94-
dst.set_lod(buffer_[read_pos_][i].lod());
95-
out->push_back(dst);
96-
}
97-
98-
++read_pos_;
99-
if (read_pos_ >= kDoubleBufferSize) {
100-
read_pos_ = 0;
101-
}
102-
buffer_not_full_.notify_all();
77+
buffer_->Receive(out);
10378
}
10479

105-
bool DoubleBufferReader::HasNext() const {
106-
return reader_->HasNext() || !buffer_.empty();
80+
void DoubleBufferReader::ReInit() {
81+
reader_->ReInit();
82+
buffer_->Close();
83+
// The existing prefetch thread will terminate for the buffer_ is closed.
84+
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
85+
kDoubleBufferSize);
86+
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
87+
prefetch.detach();
10788
}
10889

10990
void DoubleBufferReader::PrefetchThreadFunc() {
110-
while (reader_->HasNext()) {
111-
std::unique_lock<std::mutex> lck(mtx_);
112-
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
113-
buffer_not_full_.wait(lck);
91+
VLOG(5) << "A new prefetch thread starts.";
92+
while (true) {
93+
std::vector<framework::LoDTensor> batch;
94+
reader_->ReadNext(&batch);
95+
if (batch.empty()) {
96+
// EOF
97+
buffer_->Close();
98+
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
99+
break;
114100
}
115-
reader_->ReadNext(&buffer_[write_pos_]);
116-
++write_pos_;
117-
if (write_pos_ >= kDoubleBufferSize) {
118-
write_pos_ = 0;
101+
if (!buffer_->Send(&batch)) {
102+
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
103+
"prefetch thread terminates.";
104+
break;
119105
}
120-
buffer_not_empty_.notify_all();
121106
}
122107
}
123108

0 commit comments

Comments
 (0)