Skip to content

Commit 87a5590

Browse files
authored
Merge pull request #11151 from JiayiFeng/dev_update_open_files_op
Update open files op
2 parents 2a5cb2e + 3526ac1 commit 87a5590

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 26 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
2626
MultiFileReader(const std::vector<std::string>& file_names,
2727
const std::vector<framework::DDim>& dims, size_t thread_num,
2828
size_t buffer_size)
29-
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
29+
: buffer_size_(buffer_size) {
30+
readers_.reserve(file_names.size());
31+
for (const std::string& f_name : file_names) {
32+
readers_.emplace_back(CreateReaderByFileName(f_name, dims));
33+
}
3034
prefetchers_.resize(thread_num);
3135
StartNewScheduler();
3236
}
@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
4044
void StartNewScheduler();
4145
void EndScheduler();
4246
void ScheduleThreadFunc();
43-
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
47+
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
4448

45-
std::vector<std::string> file_names_;
46-
std::vector<framework::DDim> dims_;
49+
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
4750
std::thread scheduler_;
4851
std::vector<std::thread> prefetchers_;
4952
size_t buffer_size_;
50-
reader::BlockingQueue<size_t>* waiting_file_idx_;
53+
reader::BlockingQueue<size_t>* waiting_reader_idx_;
5154
reader::BlockingQueue<size_t>* available_thread_idx_;
5255
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
5356
};
@@ -65,15 +68,15 @@ void MultiFileReader::ReInit() {
6568

6669
void MultiFileReader::StartNewScheduler() {
6770
size_t thread_num = prefetchers_.size();
68-
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
71+
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
6972
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
7073
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
7174
buffer_size_);
7275

73-
for (size_t i = 0; i < file_names_.size(); ++i) {
74-
waiting_file_idx_->Send(i);
76+
for (size_t i = 0; i < readers_.size(); ++i) {
77+
waiting_reader_idx_->Send(i);
7578
}
76-
waiting_file_idx_->Close();
79+
waiting_reader_idx_->Close();
7780
for (size_t i = 0; i < thread_num; ++i) {
7881
available_thread_idx_->Send(i);
7982
}
@@ -84,13 +87,13 @@ void MultiFileReader::StartNewScheduler() {
8487
void MultiFileReader::EndScheduler() {
8588
available_thread_idx_->Close();
8689
buffer_->Close();
87-
waiting_file_idx_->Close();
90+
waiting_reader_idx_->Close();
8891
if (scheduler_.joinable()) {
8992
scheduler_.join();
9093
}
9194
delete buffer_;
9295
delete available_thread_idx_;
93-
delete waiting_file_idx_;
96+
delete waiting_reader_idx_;
9497
}
9598

9699
void MultiFileReader::ScheduleThreadFunc() {
@@ -102,12 +105,11 @@ void MultiFileReader::ScheduleThreadFunc() {
102105
if (prefetcher.joinable()) {
103106
prefetcher.join();
104107
}
105-
size_t file_idx;
106-
if (waiting_file_idx_->Receive(&file_idx)) {
108+
size_t reader_idx;
109+
if (waiting_reader_idx_->Receive(&reader_idx)) {
107110
// Still have files to read. Start a new prefetch thread.
108-
std::string file_name = file_names_[file_idx];
109-
prefetcher = std::thread([this, file_name, thread_idx] {
110-
PrefetchThreadFunc(file_name, thread_idx);
111+
prefetcher = std::thread([this, reader_idx, thread_idx] {
112+
PrefetchThreadFunc(reader_idx, thread_idx);
111113
});
112114
} else {
113115
// No more file to read.
@@ -129,23 +131,22 @@ void MultiFileReader::ScheduleThreadFunc() {
129131
VLOG(5) << "MultiFileReader schedule thread terminates.";
130132
}
131133

132-
void MultiFileReader::PrefetchThreadFunc(std::string file_name,
133-
size_t thread_idx) {
134-
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
135-
std::unique_ptr<framework::ReaderBase> reader =
136-
CreateReaderByFileName(file_name, dims_);
134+
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
135+
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
136+
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
137137
while (true) {
138138
std::vector<framework::LoDTensor> ins;
139139
reader->ReadNext(&ins);
140140
if (ins.empty()) {
141+
reader->ReInit();
141142
break;
142143
}
143144
try {
144145
buffer_->Send(std::move(ins));
145146
} catch (paddle::platform::EnforceNotMet e) {
146147
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
147-
"thread of file '"
148-
<< file_name << "' will terminate.";
148+
"thread of file idx '"
149+
<< reader_idx << "' will terminate.";
149150
break;
150151
}
151152
}
@@ -154,7 +155,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
154155
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
155156
"Fail to send thread_idx.";
156157
}
157-
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
158+
VLOG(5) << "The prefetch thread of file idx '" << reader_idx
159+
<< "' terminates.";
158160
}
159161

160162
class OpenFilesOp : public framework::OperatorBase {

0 commit comments

Comments
 (0)