@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
26
26
MultiFileReader (const std::vector<std::string>& file_names,
27
27
const std::vector<framework::DDim>& dims, size_t thread_num,
28
28
size_t buffer_size)
29
- : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
29
+ : buffer_size_(buffer_size) {
30
+ readers_.reserve (file_names.size ());
31
+ for (const std::string& f_name : file_names) {
32
+ readers_.emplace_back (CreateReaderByFileName (f_name, dims));
33
+ }
30
34
prefetchers_.resize (thread_num);
31
35
StartNewScheduler ();
32
36
}
@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
40
44
void StartNewScheduler ();
41
45
void EndScheduler ();
42
46
void ScheduleThreadFunc ();
43
- void PrefetchThreadFunc (std::string file_name , size_t thread_idx);
47
+ void PrefetchThreadFunc (size_t reader_idx , size_t thread_idx);
44
48
45
- std::vector<std::string> file_names_;
46
- std::vector<framework::DDim> dims_;
49
+ std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
47
50
std::thread scheduler_;
48
51
std::vector<std::thread> prefetchers_;
49
52
size_t buffer_size_;
50
- reader::BlockingQueue<size_t >* waiting_file_idx_ ;
53
+ reader::BlockingQueue<size_t >* waiting_reader_idx_ ;
51
54
reader::BlockingQueue<size_t >* available_thread_idx_;
52
55
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
53
56
};
@@ -65,15 +68,15 @@ void MultiFileReader::ReInit() {
65
68
66
69
void MultiFileReader::StartNewScheduler () {
67
70
size_t thread_num = prefetchers_.size ();
68
- waiting_file_idx_ = new reader::BlockingQueue<size_t >(file_names_ .size ());
71
+ waiting_reader_idx_ = new reader::BlockingQueue<size_t >(readers_ .size ());
69
72
available_thread_idx_ = new reader::BlockingQueue<size_t >(thread_num);
70
73
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
71
74
buffer_size_);
72
75
73
- for (size_t i = 0 ; i < file_names_ .size (); ++i) {
74
- waiting_file_idx_ ->Send (i);
76
+ for (size_t i = 0 ; i < readers_ .size (); ++i) {
77
+ waiting_reader_idx_ ->Send (i);
75
78
}
76
- waiting_file_idx_ ->Close ();
79
+ waiting_reader_idx_ ->Close ();
77
80
for (size_t i = 0 ; i < thread_num; ++i) {
78
81
available_thread_idx_->Send (i);
79
82
}
@@ -84,13 +87,13 @@ void MultiFileReader::StartNewScheduler() {
84
87
void MultiFileReader::EndScheduler () {
85
88
available_thread_idx_->Close ();
86
89
buffer_->Close ();
87
- waiting_file_idx_ ->Close ();
90
+ waiting_reader_idx_ ->Close ();
88
91
if (scheduler_.joinable ()) {
89
92
scheduler_.join ();
90
93
}
91
94
delete buffer_;
92
95
delete available_thread_idx_;
93
- delete waiting_file_idx_ ;
96
+ delete waiting_reader_idx_ ;
94
97
}
95
98
96
99
void MultiFileReader::ScheduleThreadFunc () {
@@ -102,12 +105,11 @@ void MultiFileReader::ScheduleThreadFunc() {
102
105
if (prefetcher.joinable ()) {
103
106
prefetcher.join ();
104
107
}
105
- size_t file_idx ;
106
- if (waiting_file_idx_ ->Receive (&file_idx )) {
108
+ size_t reader_idx ;
109
+ if (waiting_reader_idx_ ->Receive (&reader_idx )) {
107
110
// Still have files to read. Start a new prefetch thread.
108
- std::string file_name = file_names_[file_idx];
109
- prefetcher = std::thread ([this , file_name, thread_idx] {
110
- PrefetchThreadFunc (file_name, thread_idx);
111
+ prefetcher = std::thread ([this , reader_idx, thread_idx] {
112
+ PrefetchThreadFunc (reader_idx, thread_idx);
111
113
});
112
114
} else {
113
115
// No more file to read.
@@ -129,23 +131,22 @@ void MultiFileReader::ScheduleThreadFunc() {
129
131
VLOG (5 ) << " MultiFileReader schedule thread terminates." ;
130
132
}
131
133
132
- void MultiFileReader::PrefetchThreadFunc (std::string file_name,
133
- size_t thread_idx) {
134
- VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' starts." ;
135
- std::unique_ptr<framework::ReaderBase> reader =
136
- CreateReaderByFileName (file_name, dims_);
134
+ void MultiFileReader::PrefetchThreadFunc (size_t reader_idx, size_t thread_idx) {
135
+ VLOG (5 ) << " The prefetch thread of file idx '" << reader_idx << " ' starts." ;
136
+ std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
137
137
while (true ) {
138
138
std::vector<framework::LoDTensor> ins;
139
139
reader->ReadNext (&ins);
140
140
if (ins.empty ()) {
141
+ reader->ReInit ();
141
142
break ;
142
143
}
143
144
try {
144
145
buffer_->Send (std::move (ins));
145
146
} catch (paddle::platform::EnforceNotMet e) {
146
147
VLOG (5 ) << " WARNING: The buffer channel has been closed. The prefetch "
147
- " thread of file '"
148
- << file_name << " ' will terminate." ;
148
+ " thread of file idx '"
149
+ << reader_idx << " ' will terminate." ;
149
150
break ;
150
151
}
151
152
}
@@ -154,7 +155,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
154
155
VLOG (5 ) << " WARNING: The available_thread_idx_ channel has been closed. "
155
156
" Fail to send thread_idx." ;
156
157
}
157
- VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' terminates." ;
158
+ VLOG (5 ) << " The prefetch thread of file idx '" << reader_idx
159
+ << " ' terminates." ;
158
160
}
159
161
160
162
class OpenFilesOp : public framework ::OperatorBase {
0 commit comments