Skip to content

Commit f2c0b88

Browse files
authored
Merge pull request #9550 from JiayiFeng/make_MultipleReader_thread-safe
Make MultipleReader thread-safe
2 parents 232b6fc + 2945a98 commit f2c0b88

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,22 @@ namespace reader {
2121

2222
class MultipleReader : public framework::ReaderBase {
2323
public:
24+
class ThreadBufferMap {
25+
public:
26+
std::vector<framework::LoDTensor>& operator[](
27+
const std::thread::id& thread_id) {
28+
std::lock_guard<std::mutex> lock(mutex_);
29+
return buffer_[thread_id];
30+
}
31+
32+
void Clear() { buffer_.clear(); }
33+
34+
private:
35+
std::mutex mutex_;
36+
std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
37+
buffer_;
38+
};
39+
2440
MultipleReader(const std::vector<std::string>& file_names,
2541
const std::vector<framework::DDim>& dims, size_t thread_num)
2642
: file_names_(file_names), dims_(dims) {
@@ -47,28 +63,27 @@ class MultipleReader : public framework::ReaderBase {
4763
framework::Channel<size_t>* waiting_file_idx_;
4864
framework::Channel<size_t>* available_thread_idx_;
4965
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
50-
mutable std::vector<framework::LoDTensor> local_buffer_;
66+
mutable ThreadBufferMap thread_buffer_map_;
5167
};
5268

5369
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
5470
if (!HasNext()) {
5571
PADDLE_THROW("There is no next data!");
5672
}
57-
58-
if (local_buffer_.empty()) {
59-
buffer_->Receive(&local_buffer_);
60-
}
61-
*out = local_buffer_;
62-
local_buffer_.clear();
73+
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
74+
*out = thread_local_buffer;
75+
thread_local_buffer.clear();
6376
}
6477

6578
bool MultipleReader::HasNext() const {
66-
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
79+
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
80+
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
81+
: true;
6782
}
6883

6984
void MultipleReader::ReInit() {
7085
EndScheduler();
71-
local_buffer_.clear();
86+
thread_buffer_map_.Clear();
7287
StartNewScheduler();
7388
}
7489

@@ -176,7 +191,7 @@ class OpenFilesOp : public framework::OperatorBase {
176191
const auto& ranks = Attr<std::vector<int>>("ranks");
177192
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
178193
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
179-
int(shape_concat.size()),
194+
static_cast<int>(shape_concat.size()),
180195
"The accumulate of all ranks should be equal to the "
181196
"shape concat's length.");
182197
const auto& file_names = Attr<std::vector<std::string>>("file_names");

0 commit comments

Comments
 (0)