Skip to content

Commit 9c7fa6f

Browse files
authored
Merge pull request #10206 from JiayiFeng/blocking_queue_for_reader
Blocking queue for reader
2 parents c02ba51 + 8bd3466 commit 9c7fa6f

File tree

5 files changed

+357
-46
lines changed

5 files changed

+357
-46
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
2323
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
2424
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
2525
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
26+
27+
cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
2628
# Export local libraries to parent
2729
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <condition_variable>  // NOLINT
#include <deque>
#include <mutex>  // NOLINT

#include "paddle/fluid/platform/enforce.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
namespace reader {
25+
26+
template <typename T>
27+
class BlockingQueue {
28+
// BlockingQueue is for buffered reading and is supposed to use only the
29+
// reader package. It is true that we could and we should have been using
30+
// framework::Channel, but which has currently a deadlock bug. BlockingQueue
31+
// is a workaround and a simplified version of framework::Channel as it
32+
// doesn't support GPU and it implements on buffered blocking queue.
33+
public:
34+
explicit BlockingQueue(size_t capacity)
35+
: capacity_(capacity), closed_(false) {
36+
PADDLE_ENFORCE_GT(
37+
capacity_, 0,
38+
"The capacity of a reader::BlockingQueue must be greater than 0.");
39+
}
40+
41+
bool Send(const T& elem) {
42+
std::unique_lock<std::mutex> lock(mutex_);
43+
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
44+
if (closed_) {
45+
VLOG(5)
46+
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
47+
return false;
48+
}
49+
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
50+
queue_.push_back(elem);
51+
receive_cv_.notify_one();
52+
return true;
53+
}
54+
55+
bool Send(T&& elem) {
56+
std::unique_lock<std::mutex> lock(mutex_);
57+
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
58+
if (closed_) {
59+
VLOG(5)
60+
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
61+
return false;
62+
}
63+
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
64+
queue_.emplace_back(std::move(elem));
65+
receive_cv_.notify_one();
66+
return true;
67+
}
68+
69+
bool Receive(T* elem) {
70+
std::unique_lock<std::mutex> lock(mutex_);
71+
receive_cv_.wait(lock, [&] { return !queue_.empty() || closed_; });
72+
if (!queue_.empty()) {
73+
PADDLE_ENFORCE_NOT_NULL(elem);
74+
*elem = queue_.front();
75+
queue_.pop_front();
76+
send_cv_.notify_one();
77+
return true;
78+
} else {
79+
PADDLE_ENFORCE(closed_);
80+
return false;
81+
}
82+
}
83+
84+
void Close() {
85+
std::lock_guard<std::mutex> lock(mutex_);
86+
closed_ = true;
87+
send_cv_.notify_all();
88+
receive_cv_.notify_all();
89+
}
90+
91+
bool IsClosed() {
92+
std::lock_guard<std::mutex> lock(mutex_);
93+
return closed_;
94+
}
95+
96+
size_t Cap() {
97+
std::lock_guard<std::mutex> lock(mutex_);
98+
return capacity_;
99+
}
100+
101+
private:
102+
size_t capacity_;
103+
bool closed_;
104+
std::deque<T> queue_;
105+
106+
std::mutex mutex_;
107+
std::condition_variable receive_cv_;
108+
std::condition_variable send_cv_;
109+
};
110+
} // namespace reader
111+
} // namespace operators
112+
} // namespace paddle

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include <thread> // NOLINT
1616

17-
#include "paddle/fluid/framework/channel.h"
17+
#include "paddle/fluid/operators/reader/blocking_queue.h"
1818
#include "paddle/fluid/operators/reader/reader_op_registry.h"
1919

2020
namespace paddle {
@@ -23,13 +23,13 @@ namespace reader {
2323

2424
// 'Double buffer' means we shall maintain two batches of input data at the same
2525
// time. So the kCacheSize should be at least 2.
26-
static constexpr size_t kCacheSize = 2;
26+
static constexpr size_t kCacheSize = 3;
2727
// There will be two batches out of the channel during training:
2828
// 1. the one waiting to be sent to the channel
2929
// 2. the one just be received from the channel, which is also being used by
3030
// subsequent operators.
3131
// So the channel size should be kCacheSize - 2
32-
static constexpr size_t kChannelSize = 0; // kCacheSize - 2
32+
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
3333

3434
class DoubleBufferReader : public framework::DecoratedReader {
3535
public:
@@ -55,10 +55,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
5555
~DoubleBufferReader() { EndPrefetcher(); }
5656

5757
private:
58-
bool HasNext() const;
59-
6058
void StartPrefetcher() {
61-
channel_ = framework::MakeChannel<size_t>(kChannelSize);
59+
channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
6260
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
6361
}
6462

@@ -74,7 +72,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
7472
void PrefetchThreadFunc();
7573

7674
std::thread prefetcher_;
77-
framework::Channel<size_t>* channel_;
75+
reader::BlockingQueue<size_t>* channel_;
7876
platform::Place place_;
7977
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
8078
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
@@ -139,17 +137,16 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
139137
};
140138

141139
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
142-
out->clear();
143-
if (HasNext()) {
144-
size_t cached_tensor_id;
145-
channel_->Receive(&cached_tensor_id);
140+
size_t cached_tensor_id;
141+
if (channel_->Receive(&cached_tensor_id)) {
146142
if (platform::is_gpu_place(place_)) {
147143
*out = gpu_tensor_cache_[cached_tensor_id];
148-
ctxs_[cached_tensor_id]->Wait();
149144
} else {
150145
// CPU place
151146
*out = cpu_tensor_cache_[cached_tensor_id];
152147
}
148+
} else {
149+
out->clear();
153150
}
154151
}
155152

@@ -159,12 +156,6 @@ void DoubleBufferReader::ReInit() {
159156
StartPrefetcher();
160157
}
161158

162-
bool DoubleBufferReader::HasNext() const {
163-
while (!channel_->IsClosed() && !channel_->CanReceive()) {
164-
}
165-
return channel_->CanReceive();
166-
}
167-
168159
void DoubleBufferReader::PrefetchThreadFunc() {
169160
VLOG(5) << "A new prefetch thread starts.";
170161
size_t cached_tensor_id = 0;
@@ -185,10 +176,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
185176
gpu_batch[i].set_lod(cpu_batch[i].lod());
186177
}
187178
}
188-
try {
189-
size_t tmp = cached_tensor_id;
190-
channel_->Send(&tmp);
191-
} catch (paddle::platform::EnforceNotMet e) {
179+
if (!channel_->Send(cached_tensor_id)) {
192180
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
193181
"prefetch thread will terminate.";
194182
break;

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include <thread> // NOLINT
1616

17-
#include "paddle/fluid/framework/channel.h"
17+
#include "paddle/fluid/operators/reader/blocking_queue.h"
1818
#include "paddle/fluid/operators/reader/reader_op_registry.h"
1919

2020
namespace paddle {
@@ -37,7 +37,6 @@ class MultiFileReader : public framework::ReaderBase {
3737
~MultiFileReader() { EndScheduler(); }
3838

3939
private:
40-
bool HasNext();
4140
void StartNewScheduler();
4241
void EndScheduler();
4342
void ScheduleThreadFunc();
@@ -48,15 +47,14 @@ class MultiFileReader : public framework::ReaderBase {
4847
std::thread scheduler_;
4948
std::vector<std::thread> prefetchers_;
5049
size_t buffer_size_;
51-
framework::Channel<size_t>* waiting_file_idx_;
52-
framework::Channel<size_t>* available_thread_idx_;
53-
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
50+
reader::BlockingQueue<size_t>* waiting_file_idx_;
51+
reader::BlockingQueue<size_t>* available_thread_idx_;
52+
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
5453
};
5554

5655
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
57-
out->clear();
58-
if (HasNext()) {
59-
buffer_->Receive(out);
56+
if (!buffer_->Receive(out)) {
57+
out->clear();
6058
}
6159
}
6260

@@ -65,25 +63,19 @@ void MultiFileReader::ReInit() {
6563
StartNewScheduler();
6664
}
6765

68-
bool MultiFileReader::HasNext() {
69-
while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
70-
}
71-
return buffer_->CanReceive();
72-
}
73-
7466
void MultiFileReader::StartNewScheduler() {
7567
size_t thread_num = prefetchers_.size();
76-
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
77-
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
78-
buffer_ =
79-
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
68+
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
69+
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
70+
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
71+
buffer_size_);
8072

8173
for (size_t i = 0; i < file_names_.size(); ++i) {
82-
waiting_file_idx_->Send(&i);
74+
waiting_file_idx_->Send(i);
8375
}
8476
waiting_file_idx_->Close();
8577
for (size_t i = 0; i < thread_num; ++i) {
86-
available_thread_idx_->Send(&i);
78+
available_thread_idx_->Send(i);
8779
}
8880

8981
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
@@ -149,7 +141,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
149141
break;
150142
}
151143
try {
152-
buffer_->Send(&ins);
144+
buffer_->Send(std::move(ins));
153145
} catch (paddle::platform::EnforceNotMet e) {
154146
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
155147
"thread of file '"
@@ -158,9 +150,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
158150
}
159151
}
160152

161-
try {
162-
available_thread_idx_->Send(&thread_idx);
163-
} catch (paddle::platform::EnforceNotMet e) {
153+
if (!available_thread_idx_->Send(thread_idx)) {
164154
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
165155
"Fail to send thread_idx.";
166156
}

0 commit comments

Comments
 (0)