Skip to content

Commit 7c041e4

Browse files
authored
Merge pull request #9182 from JiayiFeng/dev_MultipleReader
Multi-threaded reader in C++
2 parents e4bd63d + 2532b92 commit 7c041e4

File tree

8 files changed

+348
-4
lines changed

8 files changed

+348
-4
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ function(reader_library TARGET_NAME)
1515
PARENT_SCOPE)
1616
endfunction()
1717

18+
reader_library(open_files_op SRCS open_files_op.cc)
1819
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
1920
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
2021
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
124124
};
125125

126126
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
127+
if (!HasNext()) {
128+
PADDLE_THROW("There is no next data!");
129+
}
130+
127131
if (local_buffer_.payloads_.empty()) {
128132
buffer_->Receive(&local_buffer_);
129133
}
130-
131134
*out = local_buffer_.payloads_;
132135
local_buffer_.payloads_.clear();
133136
if (local_buffer_.ctx_) {
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/channel.h"
16+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
namespace reader {
21+
22+
class MultipleReader : public framework::ReaderBase {
23+
public:
24+
MultipleReader(const std::vector<std::string>& file_names,
25+
const std::vector<framework::DDim>& dims, size_t thread_num)
26+
: file_names_(file_names), dims_(dims) {
27+
prefetchers_.resize(thread_num);
28+
StartNewScheduler();
29+
}
30+
31+
void ReadNext(std::vector<framework::LoDTensor>* out) override;
32+
bool HasNext() const override;
33+
void ReInit() override;
34+
35+
~MultipleReader() { EndScheduler(); }
36+
37+
private:
38+
void StartNewScheduler();
39+
void EndScheduler();
40+
void ScheduleThreadFunc();
41+
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
42+
43+
std::vector<std::string> file_names_;
44+
std::vector<framework::DDim> dims_;
45+
std::thread scheduler_;
46+
std::vector<std::thread> prefetchers_;
47+
framework::Channel<size_t>* waiting_file_idx_;
48+
framework::Channel<size_t>* available_thread_idx_;
49+
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
50+
mutable std::vector<framework::LoDTensor> local_buffer_;
51+
};
52+
53+
// Copies the next batch into `out` and empties the local buffer.
// Raises through PADDLE_THROW when HasNext() reports that nothing is left.
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  const bool has_data = HasNext();
  if (!has_data) {
    PADDLE_THROW("There is no next data!");
  }

  // HasNext() tries to leave a batch in local_buffer_; receive again only if
  // the buffer is still empty.
  if (local_buffer_.empty()) {
    buffer_->Receive(&local_buffer_);
  }
  *out = local_buffer_;
  local_buffer_.clear();
}
64+
65+
bool MultipleReader::HasNext() const {
66+
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
67+
}
68+
69+
// Restarts reading from scratch: tears down the current scheduler and its
// channels, drops any batch cached locally, then starts a new scheduler.
void MultipleReader::ReInit() {
  EndScheduler();
  local_buffer_.clear();  // discard the batch fetched before the restart
  StartNewScheduler();
}
74+
75+
void MultipleReader::StartNewScheduler() {
76+
size_t thread_num = prefetchers_.size();
77+
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
78+
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
79+
buffer_ =
80+
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
81+
82+
for (size_t i = 0; i < file_names_.size(); ++i) {
83+
waiting_file_idx_->Send(&i);
84+
}
85+
waiting_file_idx_->Close();
86+
for (size_t i = 0; i < thread_num; ++i) {
87+
available_thread_idx_->Send(&i);
88+
}
89+
90+
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
91+
}
92+
93+
void MultipleReader::EndScheduler() {
94+
available_thread_idx_->Close();
95+
buffer_->Close();
96+
waiting_file_idx_->Close();
97+
if (scheduler_.joinable()) {
98+
scheduler_.join();
99+
}
100+
delete buffer_;
101+
delete available_thread_idx_;
102+
delete waiting_file_idx_;
103+
}
104+
105+
void MultipleReader::ScheduleThreadFunc() {
106+
VLOG(5) << "MultipleReader schedule thread starts.";
107+
size_t completed_thread_num = 0;
108+
size_t thread_idx;
109+
while (available_thread_idx_->Receive(&thread_idx)) {
110+
std::thread& prefetcher = prefetchers_[thread_idx];
111+
if (prefetcher.joinable()) {
112+
prefetcher.join();
113+
}
114+
size_t file_idx;
115+
if (waiting_file_idx_->Receive(&file_idx)) {
116+
// Still have files to read. Start a new prefetch thread.
117+
std::string file_name = file_names_[file_idx];
118+
prefetcher = std::thread([this, file_name, thread_idx] {
119+
PrefetchThreadFunc(file_name, thread_idx);
120+
});
121+
} else {
122+
// No more file to read.
123+
++completed_thread_num;
124+
if (completed_thread_num == prefetchers_.size()) {
125+
buffer_->Close();
126+
break;
127+
}
128+
}
129+
}
130+
// If users invoke ReInit() when scheduler is running, it will close the
131+
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
132+
// to release their resource. So a check is needed before scheduler ends.
133+
for (auto& p : prefetchers_) {
134+
if (p.joinable()) {
135+
p.join();
136+
}
137+
}
138+
VLOG(5) << "MultipleReader schedule thread terminates.";
139+
}
140+
141+
void MultipleReader::PrefetchThreadFunc(std::string file_name,
142+
size_t thread_idx) {
143+
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
144+
std::unique_ptr<framework::ReaderBase> reader =
145+
CreateReaderByFileName(file_name, dims_);
146+
while (reader->HasNext()) {
147+
std::vector<framework::LoDTensor> ins;
148+
reader->ReadNext(&ins);
149+
if (!buffer_->Send(&ins)) {
150+
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
151+
"thread of file '"
152+
<< file_name << "' will terminate.";
153+
break;
154+
}
155+
}
156+
if (!available_thread_idx_->Send(&thread_idx)) {
157+
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
158+
"Fail to send thread_idx.";
159+
}
160+
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
161+
}
162+
163+
class OpenFilesOp : public framework::OperatorBase {
164+
public:
165+
using framework::OperatorBase::OperatorBase;
166+
167+
private:
168+
void RunImpl(const framework::Scope& scope,
169+
const platform::Place& dev_place) const override {
170+
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
171+
const auto& ranks = Attr<std::vector<int>>("ranks");
172+
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
173+
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
174+
int(shape_concat.size()),
175+
"The accumulate of all ranks should be equal to the "
176+
"shape concat's length.");
177+
const auto& file_names = Attr<std::vector<std::string>>("file_names");
178+
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
179+
const size_t thread_num = Attr<int>("thread_num");
180+
181+
auto* out = scope.FindVar(Output("Out"))
182+
->template GetMutable<framework::ReaderHolder>();
183+
out->Reset(new MultipleReader(
184+
file_names, RestoreShapes(shape_concat, ranks), thread_num));
185+
}
186+
};
187+
188+
class OpenFilesOpMaker : public FileReaderMakerBase {
189+
public:
190+
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
191+
: FileReaderMakerBase(op_proto, op_checker) {
192+
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
193+
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
194+
.GreaterThan(0);
195+
196+
AddComment(R"DOC(
197+
OpenFiles Operator
198+
199+
An OpenFilesOp creates a MultipleReader, which is able to
200+
read data multi-threaded from multiple files.
201+
)DOC");
202+
}
203+
};
204+
205+
} // namespace reader
206+
} // namespace operators
207+
} // namespace paddle
208+
209+
namespace reader = paddle::operators::reader;
210+
211+
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
212+
reader::OpenFilesOpMaker);

paddle/fluid/operators/reader/reader_op_registry.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
3636
return regs;
3737
}
3838

39+
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
40+
const std::string& file_name, const std::vector<framework::DDim>& dims) {
41+
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
42+
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
43+
"File name illegal! A legal file name should be like: "
44+
"[file_name].[file_format] (e.g., 'data_file.recordio').");
45+
std::string filetype = file_name.substr(separator_pos + 1);
46+
47+
auto itor = FileReaderRegistry().find(filetype);
48+
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
49+
"No file reader registered for '%s' format.", filetype);
50+
framework::ReaderBase* reader = (itor->second)(file_name, dims);
51+
return std::unique_ptr<framework::ReaderBase>(reader);
52+
}
53+
3954
FileReaderMakerBase::FileReaderMakerBase(
4055
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
4156
framework::OpAttrChecker* op_checker)

paddle/fluid/operators/reader/reader_op_registry.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace paddle {
2121
namespace operators {
2222
namespace reader {
2323

24+
// Separator between a file's name and its format suffix, e.g. the "." in
// 'data_file.recordio'.
static constexpr char kFileFormatSeparator[] = ".";
25+
2426
using FileReaderCreator = std::function<framework::ReaderBase*(
2527
const std::string&, const std::vector<framework::DDim>&)>;
2628

@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
2931
template <typename Reader>
3032
int RegisterFileReader(const std::string& filetype) {
3133
FileReaderRegistry()[filetype] = [](
32-
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
33-
return new Reader(fn, dim);
34+
const std::string& fn, const std::vector<framework::DDim>& dims) {
35+
return new Reader(fn, dims);
3436
};
3537
return 0;
3638
}
3739

40+
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
41+
const std::string& file_name, const std::vector<framework::DDim>& dims);
42+
3843
extern std::vector<framework::DDim> RestoreShapes(
3944
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
4045

python/paddle/fluid/layers/io.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
__all__ = [
2323
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
24-
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
24+
'open_files', 'read_file', 'create_shuffle_reader',
25+
'create_double_buffer_reader'
2526
]
2627

2728

@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
287288
startup_var)
288289

289290

291+
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
    """Append an 'open_files' op to the startup program and return its reader.

    Args:
        filenames: Passed to the op as its 'file_names' attribute.
        thread_num: Passed to the op as its 'thread_num' attribute.
        shapes: Iterable of shapes. They are flattened into the
            'shape_concat' attribute, and the length of each one is recorded
            in the 'ranks' attribute.
        lod_levels: Passed to the op as its 'lod_levels' attribute.
        dtypes: Data types, each converted with convert_np_dtype_to_dtype_
            and set on the reader variable.

    Returns:
        The result of copying the persistable startup reader variable into
        the current block of the default main program.
    """
    converted_dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]

    flat_dims = []
    rank_per_shape = []
    for one_shape in shapes:
        flat_dims.extend(one_shape)
        rank_per_shape.append(len(one_shape))

    reader_name = unique_name('multiple_reader')

    startup_block = default_startup_program().current_block()
    reader_var = startup_block.create_var(name=reader_name)
    op_attrs = {
        'shape_concat': flat_dims,
        'lod_levels': lod_levels,
        'ranks': rank_per_shape,
        'file_names': filenames,
        'thread_num': thread_num
    }
    startup_block.append_op(
        type='open_files', outputs={'Out': [reader_var]}, attrs=op_attrs)

    reader_var.desc.set_dtypes(converted_dtypes)
    reader_var.persistable = True

    main_block = default_main_program().current_block()
    return _copy_reader_var_(main_block, reader_var)
319+
320+
290321
def __create_decorated_reader__(op_type, reader, attrs):
291322
var_name = unique_name(op_type)
292323
startup_blk = default_startup_program().current_block()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
mnist.recordio
2+
mnist_0.recordio
3+
mnist_1.recordio
4+
mnist_2.recordio

0 commit comments

Comments
 (0)