Skip to content

Commit 5506225

Browse files
committed
Add MultipleReader and open_files_op
1 parent 128adf5 commit 5506225

File tree

4 files changed

+224
-3
lines changed

4 files changed

+224
-3
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
2020
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
2121
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
2222
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
23+
reader_library(open_files_op SRCS open_files_op.cc)
2324
# Export local libraries to parent
2425
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
120120
};
121121

122122
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
123+
if (!HasNext()) {
124+
PADDLE_THROW("There is no next data!");
125+
}
126+
123127
if (local_buffer_.payloads_.empty()) {
124128
buffer_->Receive(&local_buffer_);
125129
}
126-
127130
*out = local_buffer_.payloads_;
128131
local_buffer_.payloads_.clear();
129132
if (local_buffer_.ctx_) {
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/channel.h"
16+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
namespace reader {
21+
22+
class MultipleReader : public framework::ReaderBase {
23+
public:
24+
struct Quota {};
25+
26+
MultipleReader(const std::vector<std::string>& file_names,
27+
const std::vector<framework::DDim>& dims, size_t thread_num)
28+
: file_names_(file_names), dims_(dims), thread_num_(thread_num) {
29+
PADDLE_ENFORCE_GT(thread_num_, 0);
30+
StartNewScheduler();
31+
}
32+
33+
void ReadNext(std::vector<framework::LoDTensor>* out) override;
34+
bool HasNext() const override;
35+
void ReInit() override;
36+
37+
private:
38+
void StartNewScheduler();
39+
void ScheduleThreadFunc();
40+
void PrefetchThreadFunc(std::string file_name);
41+
42+
std::vector<std::string> file_names_;
43+
std::vector<framework::DDim> dims_;
44+
size_t thread_num_;
45+
framework::Channel<size_t>* waiting_file_idx_;
46+
framework::Channel<Quota>* thread_quotas_;
47+
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
48+
mutable std::vector<framework::LoDTensor> local_buffer_;
49+
};
50+
51+
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
52+
if (!HasNext()) {
53+
PADDLE_THROW("There is no next data!");
54+
}
55+
56+
if (local_buffer_.empty()) {
57+
buffer_->Receive(&local_buffer_);
58+
}
59+
*out = local_buffer_;
60+
local_buffer_.clear();
61+
}
62+
63+
bool MultipleReader::HasNext() const {
64+
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
65+
}
66+
67+
void MultipleReader::ReInit() {
68+
buffer_->Close();
69+
thread_quotas_->Close();
70+
waiting_file_idx_->Close();
71+
local_buffer_.clear();
72+
73+
StartNewScheduler();
74+
}
75+
76+
void MultipleReader::StartNewScheduler() {
77+
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
78+
thread_quotas_ = framework::MakeChannel<Quota>(thread_num_);
79+
buffer_ =
80+
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_);
81+
82+
for (size_t i = 0; i < file_names_.size(); ++i) {
83+
waiting_file_idx_->Send(&i);
84+
}
85+
waiting_file_idx_->Close();
86+
for (size_t i = 0; i < thread_num_; ++i) {
87+
Quota quota;
88+
thread_quotas_->Send(&quota);
89+
}
90+
91+
std::thread scheduler([this] { ScheduleThreadFunc(); });
92+
scheduler.detach();
93+
}
94+
95+
void MultipleReader::ScheduleThreadFunc() {
96+
VLOG(5) << "MultipleReader schedule thread starts.";
97+
size_t completed_thread_num = 0;
98+
Quota quota;
99+
while (thread_quotas_->Receive(&quota)) {
100+
size_t file_idx;
101+
if (waiting_file_idx_->Receive(&file_idx)) {
102+
// Still have files to read. Start a new prefetch thread.
103+
std::string file_name = file_names_[file_idx];
104+
std::thread prefetcher(
105+
[this, file_name] { PrefetchThreadFunc(file_name); });
106+
prefetcher.detach();
107+
} else {
108+
// No more file to read.
109+
++completed_thread_num;
110+
if (completed_thread_num == thread_num_) {
111+
thread_quotas_->Close();
112+
buffer_->Close();
113+
break;
114+
}
115+
}
116+
}
117+
VLOG(5) << "MultipleReader schedule thread terminates.";
118+
}
119+
120+
void MultipleReader::PrefetchThreadFunc(std::string file_name) {
121+
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
122+
std::unique_ptr<framework::ReaderBase> reader =
123+
CreateReaderByFileName(file_name, dims_);
124+
while (reader->HasNext()) {
125+
std::vector<framework::LoDTensor> ins;
126+
reader->ReadNext(&ins);
127+
if (!buffer_->Send(&ins)) {
128+
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
129+
"thread of file '"
130+
<< file_name << "' will terminate.";
131+
break;
132+
}
133+
}
134+
Quota quota;
135+
thread_quotas_->Send(&quota);
136+
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
137+
}
138+
139+
class OpenFilesOp : public framework::OperatorBase {
140+
public:
141+
using framework::OperatorBase::OperatorBase;
142+
143+
private:
144+
void RunImpl(const framework::Scope& scope,
145+
const platform::Place& dev_place) const override {
146+
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
147+
const auto& ranks = Attr<std::vector<int>>("ranks");
148+
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
149+
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
150+
int(shape_concat.size()),
151+
"The accumulate of all ranks should be equal to the "
152+
"shape concat's length.");
153+
const auto& file_names = Attr<std::vector<std::string>>("file_names");
154+
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
155+
const size_t thread_num = Attr<int>("thread_num");
156+
157+
auto* out = scope.FindVar(Output("Out"))
158+
->template GetMutable<framework::ReaderHolder>();
159+
out->Reset(new MultipleReader(
160+
file_names, RestoreShapes(shape_concat, ranks), thread_num));
161+
}
162+
};
163+
164+
class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker {
165+
public:
166+
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
167+
: OpProtoAndCheckerMaker(op_proto, op_checker) {
168+
AddComment(R"DOC(
169+
OpenFiles Operator
170+
171+
An OpenFilesOp creates a MultipleReader, which is able to
172+
read data multi-threaded from multiple files.
173+
)DOC");
174+
AddOutput("Out", "(ReaderHolder) The created MultipleReader.");
175+
AddAttr<std::vector<int>>("shape_concat",
176+
"The concat of all data's shapes.");
177+
AddAttr<std::vector<int>>(
178+
"ranks",
179+
"The ranks of each data."
180+
"e.g."
181+
"shape_concat = [2,3,4,5,6]"
182+
"ranks = [3,2]"
183+
"It means the reader will generate two data each time,"
184+
"whose shapes are [2,3,4] and [5,6] respectively.");
185+
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
186+
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
187+
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
188+
.GreaterThan(0);
189+
}
190+
};
191+
192+
} // namespace reader
193+
} // namespace operators
194+
} // namespace paddle
195+
196+
namespace reader = paddle::operators::reader;
197+
198+
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
199+
reader::OpenFilesOpMaker);

paddle/fluid/operators/reader/reader_op_registry.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace paddle {
2121
namespace operators {
2222
namespace reader {
2323

24+
static constexpr char kFileFormatSeparator[] = ":";
25+
2426
using FileReaderCreator = std::function<framework::ReaderBase*(
2527
const std::string&, const std::vector<framework::DDim>&)>;
2628

@@ -29,12 +31,28 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
2931
template <typename Reader>
3032
int RegisterFileReader(const std::string& filetype) {
3133
FileReaderRegistry()[filetype] = [](
32-
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
33-
return new Reader(fn, dim);
34+
const std::string& fn, const std::vector<framework::DDim>& dims) {
35+
return new Reader(fn, dims);
3436
};
3537
return 0;
3638
}
3739

40+
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
41+
const std::string& file_name, const std::vector<framework::DDim>& dims) {
42+
size_t separator_pos = file_name.find(kFileFormatSeparator);
43+
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
44+
"File name illegal! A legal file name should be like: "
45+
"[file_format]:[file_name] (e.g., 'recordio:data_file').");
46+
std::string filetype = file_name.substr(0, separator_pos);
47+
std::string f_name = file_name.substr(separator_pos + 1);
48+
49+
auto itor = FileReaderRegistry().find(filetype);
50+
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
51+
"No file reader registered for '%s' format.", filetype);
52+
framework::ReaderBase* reader = (itor->second)(f_name, dims);
53+
return std::unique_ptr<framework::ReaderBase>(reader);
54+
}
55+
3856
extern std::vector<framework::DDim> RestoreShapes(
3957
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
4058

0 commit comments

Comments
 (0)