|
12 | 12 | // See the License for the specific language governing permissions and
|
13 | 13 | // limitations under the License.
|
14 | 14 |
|
15 |
| -#include <condition_variable> |
16 |
| -#include <mutex> |
17 | 15 | #include <thread>
|
| 16 | +#include "paddle/fluid/framework/channel.h" |
18 | 17 | #include "paddle/fluid/operators/reader/reader_op_registry.h"
|
19 | 18 |
|
20 | 19 | namespace paddle {
|
21 | 20 | namespace operators {
|
22 | 21 | namespace reader {
|
23 | 22 |
|
24 |
| -static constexpr size_t kDoubleBufferSize = 3; |
| 23 | +static constexpr size_t kDoubleBufferSize = 2; |
25 | 24 |
|
26 | 25 | class DoubleBufferReader : public framework::DecoratedReader {
|
27 | 26 | public:
|
28 | 27 | explicit DoubleBufferReader(ReaderBase* reader)
|
29 | 28 | : DecoratedReader(reader),
|
30 |
| - buffer_(kDoubleBufferSize), |
31 |
| - write_pos_(0), |
32 |
| - read_pos_(0) { |
33 |
| - std::thread prefetch( |
34 |
| - std::bind(&DoubleBufferReader::PrefetchThreadFunc, this)); |
| 29 | + buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>( |
| 30 | + kDoubleBufferSize)) { |
| 31 | + std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); |
35 | 32 | prefetch.detach();
|
36 | 33 | }
|
37 | 34 |
|
38 | 35 | void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
39 |
| - bool HasNext() const override; |
| 36 | + void ReInit() override; |
| 37 | + |
| 38 | + ~DoubleBufferReader() { buffer_->Close(); } |
40 | 39 |
|
41 | 40 | private:
|
42 | 41 | void PrefetchThreadFunc();
|
43 | 42 |
|
44 |
| - std::vector<std::vector<framework::LoDTensor>> buffer_; |
45 |
| - size_t write_pos_; |
46 |
| - size_t read_pos_; |
47 |
| - |
48 |
| - std::mutex mtx_; |
49 |
| - std::condition_variable buffer_not_full_; |
50 |
| - std::condition_variable buffer_not_empty_; |
| 43 | + framework::Channel<std::vector<framework::LoDTensor>>* buffer_; |
51 | 44 | };
|
52 | 45 |
|
53 | 46 | class CreateDoubleBufferReaderOp : public framework::OperatorBase {
|
@@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
|
80 | 73 | };
|
81 | 74 |
|
82 | 75 | void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
83 |
| - std::unique_lock<std::mutex> lck(mtx_); |
84 |
| - while (write_pos_ == read_pos_) { |
85 |
| - buffer_not_empty_.wait(lck); |
86 |
| - } |
87 |
| - |
88 | 76 | out->clear();
|
89 |
| - out->reserve(buffer_[read_pos_].size()); |
90 |
| - // TODO(fengjiayi): This copy shall be reduced. |
91 |
| - for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) { |
92 |
| - framework::LoDTensor dst; |
93 |
| - TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst); |
94 |
| - dst.set_lod(buffer_[read_pos_][i].lod()); |
95 |
| - out->push_back(dst); |
96 |
| - } |
97 |
| - |
98 |
| - ++read_pos_; |
99 |
| - if (read_pos_ >= kDoubleBufferSize) { |
100 |
| - read_pos_ = 0; |
101 |
| - } |
102 |
| - buffer_not_full_.notify_all(); |
| 77 | + buffer_->Receive(out); |
103 | 78 | }
|
104 | 79 |
|
105 |
| -bool DoubleBufferReader::HasNext() const { |
106 |
| - return reader_->HasNext() || !buffer_.empty(); |
| 80 | +void DoubleBufferReader::ReInit() { |
| 81 | + reader_->ReInit(); |
| 82 | + buffer_->Close(); |
| 83 | + // The existing prefetch thread will terminate for the buffer_ is closed. |
| 84 | + buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>( |
| 85 | + kDoubleBufferSize); |
| 86 | + std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); |
| 87 | + prefetch.detach(); |
107 | 88 | }
|
108 | 89 |
|
109 | 90 | void DoubleBufferReader::PrefetchThreadFunc() {
|
110 |
| - while (reader_->HasNext()) { |
111 |
| - std::unique_lock<std::mutex> lck(mtx_); |
112 |
| - while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) { |
113 |
| - buffer_not_full_.wait(lck); |
| 91 | + VLOG(5) << "A new prefetch thread starts."; |
| 92 | + while (true) { |
| 93 | + std::vector<framework::LoDTensor> batch; |
| 94 | + reader_->ReadNext(&batch); |
| 95 | + if (batch.empty()) { |
| 96 | + // EOF |
| 97 | + buffer_->Close(); |
| 98 | + VLOG(5) << "Reached the end of the file. The prefetch thread terminates."; |
| 99 | + break; |
114 | 100 | }
|
115 |
| - reader_->ReadNext(&buffer_[write_pos_]); |
116 |
| - ++write_pos_; |
117 |
| - if (write_pos_ >= kDoubleBufferSize) { |
118 |
| - write_pos_ = 0; |
| 101 | + if (!buffer_->Send(&batch)) { |
| 102 | + VLOG(5) << "WARNING: The double buffer channel has been closed. The " |
| 103 | + "prefetch thread terminates."; |
| 104 | + break; |
119 | 105 | }
|
120 |
| - buffer_not_empty_.notify_all(); |
121 | 106 | }
|
122 | 107 | }
|
123 | 108 |
|
|
0 commit comments