Skip to content

Commit ee4e567

Browse files
committed
Creating readers before training begins
1 parent e0a8c58 commit ee4e567

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
2626
MultiFileReader(const std::vector<std::string>& file_names,
2727
const std::vector<framework::DDim>& dims, size_t thread_num,
2828
size_t buffer_size)
29-
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
29+
: buffer_size_(buffer_size) {
30+
readers_.resize(file_names.size());
31+
for (const std::string& f_name : file_names) {
32+
readers_.emplace_back(CreateReaderByFileName(f_name, dims));
33+
}
3034
prefetchers_.resize(thread_num);
3135
StartNewScheduler();
3236
}
@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
4044
void StartNewScheduler();
4145
void EndScheduler();
4246
void ScheduleThreadFunc();
43-
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
47+
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
4448

45-
std::vector<std::string> file_names_;
46-
std::vector<framework::DDim> dims_;
49+
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
4750
std::thread scheduler_;
4851
std::vector<std::thread> prefetchers_;
4952
size_t buffer_size_;
50-
reader::BlockingQueue<size_t>* waiting_file_idx_;
53+
reader::BlockingQueue<size_t>* waiting_reader_idx_;
5154
reader::BlockingQueue<size_t>* available_thread_idx_;
5255
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
5356
};
@@ -60,20 +63,23 @@ void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
6063

6164
void MultiFileReader::ReInit() {
6265
EndScheduler();
66+
for (auto& reader : readers_) {
67+
reader->ReInit();
68+
}
6369
StartNewScheduler();
6470
}
6571

6672
void MultiFileReader::StartNewScheduler() {
6773
size_t thread_num = prefetchers_.size();
68-
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
74+
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
6975
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
7076
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
7177
buffer_size_);
7278

73-
for (size_t i = 0; i < file_names_.size(); ++i) {
74-
waiting_file_idx_->Send(i);
79+
for (size_t i = 0; i < readers_.size(); ++i) {
80+
waiting_reader_idx_->Send(i);
7581
}
76-
waiting_file_idx_->Close();
82+
waiting_reader_idx_->Close();
7783
for (size_t i = 0; i < thread_num; ++i) {
7884
available_thread_idx_->Send(i);
7985
}
@@ -84,13 +90,13 @@ void MultiFileReader::StartNewScheduler() {
8490
void MultiFileReader::EndScheduler() {
8591
available_thread_idx_->Close();
8692
buffer_->Close();
87-
waiting_file_idx_->Close();
93+
waiting_reader_idx_->Close();
8894
if (scheduler_.joinable()) {
8995
scheduler_.join();
9096
}
9197
delete buffer_;
9298
delete available_thread_idx_;
93-
delete waiting_file_idx_;
99+
delete waiting_reader_idx_;
94100
}
95101

96102
void MultiFileReader::ScheduleThreadFunc() {
@@ -102,12 +108,11 @@ void MultiFileReader::ScheduleThreadFunc() {
102108
if (prefetcher.joinable()) {
103109
prefetcher.join();
104110
}
105-
size_t file_idx;
106-
if (waiting_file_idx_->Receive(&file_idx)) {
111+
size_t reader_idx;
112+
if (waiting_reader_idx_->Receive(&reader_idx)) {
107113
// Still have files to read. Start a new prefetch thread.
108-
std::string file_name = file_names_[file_idx];
109-
prefetcher = std::thread([this, file_name, thread_idx] {
110-
PrefetchThreadFunc(file_name, thread_idx);
114+
prefetcher = std::thread([this, reader_idx, thread_idx] {
115+
PrefetchThreadFunc(reader_idx, thread_idx);
111116
});
112117
} else {
113118
// No more file to read.
@@ -129,11 +134,9 @@ void MultiFileReader::ScheduleThreadFunc() {
129134
VLOG(5) << "MultiFileReader schedule thread terminates.";
130135
}
131136

132-
void MultiFileReader::PrefetchThreadFunc(std::string file_name,
133-
size_t thread_idx) {
134-
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
135-
std::unique_ptr<framework::ReaderBase> reader =
136-
CreateReaderByFileName(file_name, dims_);
137+
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
138+
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
139+
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
137140
while (true) {
138141
std::vector<framework::LoDTensor> ins;
139142
reader->ReadNext(&ins);
@@ -144,8 +147,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
144147
buffer_->Send(std::move(ins));
145148
} catch (paddle::platform::EnforceNotMet e) {
146149
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
147-
"thread of file '"
148-
<< file_name << "' will terminate.";
150+
"thread of file idx '"
151+
<< reader_idx << "' will terminate.";
149152
break;
150153
}
151154
}
@@ -154,7 +157,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
154157
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
155158
"Fail to send thread_idx.";
156159
}
157-
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
160+
VLOG(5) << "The prefetch thread of file idx '" << reader_idx
161+
<< "' terminates.";
158162
}
159163

160164
class OpenFilesOp : public framework::OperatorBase {

0 commit comments

Comments
 (0)