File tree Expand file tree Collapse file tree 2 files changed +66
-0
lines changed Expand file tree Collapse file tree 2 files changed +66
-0
lines changed Original file line number Diff line number Diff line change @@ -112,5 +112,46 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
112
112
out->push_back (out_tensor);
113
113
}
114
114
}
115
+
116
+ void DoubleBufferReader::ReadNext (std::vector<LoDTensor>* out) {
117
+ std::unique_lock<std::mutex> lck (mtx_);
118
+ while (write_pos_ == read_pos_) {
119
+ buffer_not_empty_.wait (lck);
120
+ }
121
+
122
+ out->clear ();
123
+ out->resize (buffer_[read_pos_].size ());
124
+ // TODO(fengjiayi): This copy shall be reduced.
125
+ for (size_t i = 0 ; i < buffer_[read_pos_].size (); ++i) {
126
+ TensorCopy (buffer_[read_pos_][i], platform::CPUPlace (), &out[i]);
127
+ out[i].set_lod (buffer_[read_pos_][i].lod ());
128
+ }
129
+
130
+ ++read_pos_;
131
+ if (read_pos_ >= kDoubleBufferSize ) {
132
+ read_pos_ = 0 ;
133
+ }
134
+ buffer_not_full_.notify_all ();
135
+ }
136
+
137
+ bool DoubleBufferReader::HasNext () {
138
+ return reader_->HasNext () || !buffer_.empty ();
139
+ }
140
+
141
+ void DoubleBufferReader::ProducerThreadFunc () {
142
+ while (reader_->HasNext ()) {
143
+ std::unique_lock<std::mutex> lck (mtx);
144
+ while (((write_pos_ + 1 ) % kDoubleBufferSize ) == read_pos_) {
145
+ buffer_not_full_.wait (lck);
146
+ }
147
+ reader_->ReadNext (&buffer_[write_pos_]);
148
+ ++write_pos_;
149
+ if (write_pos_ >= kDoubleBufferSize ) {
150
+ write_pos_ = 0 ;
151
+ }
152
+ buffer_not_empty_.notify_all ();
153
+ }
154
+ }
155
+
115
156
} // namespace framework
116
157
} // namespace paddle
Original file line number Diff line number Diff line change 16
16
17
17
#include " paddle/fluid/framework/ddim.h"
18
18
#include " paddle/fluid/framework/lod_tensor_array.h"
19
+ #include " paddle/fluid/framework/threadpool.h"
19
20
20
21
namespace paddle {
21
22
namespace framework {
22
23
24
+ static constexpr size_t kDoubleBufferSize = 3 ;
25
+
23
26
class ReaderBase {
24
27
public:
25
28
explicit ReaderBase (const std::vector<DDim>& shapes) : shapes_(shapes) {
@@ -135,6 +138,28 @@ class BatchReader : public DecoratedReader {
135
138
std::vector<std::vector<LoDTensor>> buffer_;
136
139
};
137
140
141
+ class DoubleBufferReader : public DecoratedReader {
142
+ public:
143
+ DoubleBufferReader (ReaderBase* reader)
144
+ : DecoratedReader(reader), buffer_(kDoubleBufferSize ) {
145
+ framework::Async (std::bind (&DoubleBufferReader::ProducerThreadFunc, this ));
146
+ }
147
+
148
+ void ReadNext (std::vector<LoDTensor>* out) override ;
149
+ bool HasNext () const override ;
150
+
151
+ private:
152
+ void ProducerThreadFunc ();
153
+
154
+ std::vector<std::vector<LoDTensor>> buffer_;
155
+ size_t write_pos_;
156
+ size_t read_pos_;
157
+
158
+ std::mutex mtx_;
159
+ std::condition_variable buffer_not_full_;
160
+ std::condition_variable buffer_not_empty_;
161
+ };
162
+
138
163
// The ReaderHolder is used as readers' unified wrapper,
139
164
// making it easier to access different type readers in Variables.
140
165
class ReaderHolder {
You can’t perform that action at this time.
0 commit comments