@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
26
26
MultiFileReader (const std::vector<std::string>& file_names,
27
27
const std::vector<framework::DDim>& dims, size_t thread_num,
28
28
size_t buffer_size)
29
- : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
29
+ : buffer_size_(buffer_size) {
30
+ readers_.resize (file_names.size ());
31
+ for (const std::string& f_name : file_names) {
32
+ readers_.emplace_back (CreateReaderByFileName (f_name, dims));
33
+ }
30
34
prefetchers_.resize (thread_num);
31
35
StartNewScheduler ();
32
36
}
@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
40
44
void StartNewScheduler ();
41
45
void EndScheduler ();
42
46
void ScheduleThreadFunc ();
43
- void PrefetchThreadFunc (std::string file_name , size_t thread_idx);
47
+ void PrefetchThreadFunc (size_t reader_idx , size_t thread_idx);
44
48
45
- std::vector<std::string> file_names_;
46
- std::vector<framework::DDim> dims_;
49
+ std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
47
50
std::thread scheduler_;
48
51
std::vector<std::thread> prefetchers_;
49
52
size_t buffer_size_;
50
- reader::BlockingQueue<size_t >* waiting_file_idx_ ;
53
+ reader::BlockingQueue<size_t >* waiting_reader_idx_ ;
51
54
reader::BlockingQueue<size_t >* available_thread_idx_;
52
55
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
53
56
};
@@ -60,20 +63,23 @@ void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
60
63
61
64
void MultiFileReader::ReInit () {
62
65
EndScheduler ();
66
+ for (auto & reader : readers_) {
67
+ reader->ReInit ();
68
+ }
63
69
StartNewScheduler ();
64
70
}
65
71
66
72
void MultiFileReader::StartNewScheduler () {
67
73
size_t thread_num = prefetchers_.size ();
68
- waiting_file_idx_ = new reader::BlockingQueue<size_t >(file_names_ .size ());
74
+ waiting_reader_idx_ = new reader::BlockingQueue<size_t >(readers_ .size ());
69
75
available_thread_idx_ = new reader::BlockingQueue<size_t >(thread_num);
70
76
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
71
77
buffer_size_);
72
78
73
- for (size_t i = 0 ; i < file_names_ .size (); ++i) {
74
- waiting_file_idx_ ->Send (i);
79
+ for (size_t i = 0 ; i < readers_ .size (); ++i) {
80
+ waiting_reader_idx_ ->Send (i);
75
81
}
76
- waiting_file_idx_ ->Close ();
82
+ waiting_reader_idx_ ->Close ();
77
83
for (size_t i = 0 ; i < thread_num; ++i) {
78
84
available_thread_idx_->Send (i);
79
85
}
@@ -84,13 +90,13 @@ void MultiFileReader::StartNewScheduler() {
84
90
void MultiFileReader::EndScheduler () {
85
91
available_thread_idx_->Close ();
86
92
buffer_->Close ();
87
- waiting_file_idx_ ->Close ();
93
+ waiting_reader_idx_ ->Close ();
88
94
if (scheduler_.joinable ()) {
89
95
scheduler_.join ();
90
96
}
91
97
delete buffer_;
92
98
delete available_thread_idx_;
93
- delete waiting_file_idx_ ;
99
+ delete waiting_reader_idx_ ;
94
100
}
95
101
96
102
void MultiFileReader::ScheduleThreadFunc () {
@@ -102,12 +108,11 @@ void MultiFileReader::ScheduleThreadFunc() {
102
108
if (prefetcher.joinable ()) {
103
109
prefetcher.join ();
104
110
}
105
- size_t file_idx ;
106
- if (waiting_file_idx_ ->Receive (&file_idx )) {
111
+ size_t reader_idx ;
112
+ if (waiting_reader_idx_ ->Receive (&reader_idx )) {
107
113
// Still have files to read. Start a new prefetch thread.
108
- std::string file_name = file_names_[file_idx];
109
- prefetcher = std::thread ([this , file_name, thread_idx] {
110
- PrefetchThreadFunc (file_name, thread_idx);
114
+ prefetcher = std::thread ([this , reader_idx, thread_idx] {
115
+ PrefetchThreadFunc (reader_idx, thread_idx);
111
116
});
112
117
} else {
113
118
// No more file to read.
@@ -129,11 +134,9 @@ void MultiFileReader::ScheduleThreadFunc() {
129
134
VLOG (5 ) << " MultiFileReader schedule thread terminates." ;
130
135
}
131
136
132
- void MultiFileReader::PrefetchThreadFunc (std::string file_name,
133
- size_t thread_idx) {
134
- VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' starts." ;
135
- std::unique_ptr<framework::ReaderBase> reader =
136
- CreateReaderByFileName (file_name, dims_);
137
+ void MultiFileReader::PrefetchThreadFunc (size_t reader_idx, size_t thread_idx) {
138
+ VLOG (5 ) << " The prefetch thread of file idx '" << reader_idx << " ' starts." ;
139
+ std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
137
140
while (true ) {
138
141
std::vector<framework::LoDTensor> ins;
139
142
reader->ReadNext (&ins);
@@ -144,8 +147,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
144
147
buffer_->Send (std::move (ins));
145
148
} catch (paddle::platform::EnforceNotMet e) {
146
149
VLOG (5 ) << " WARNING: The buffer channel has been closed. The prefetch "
147
- " thread of file '"
148
- << file_name << " ' will terminate." ;
150
+ " thread of file idx '"
151
+ << reader_idx << " ' will terminate." ;
149
152
break ;
150
153
}
151
154
}
@@ -154,7 +157,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
154
157
VLOG (5 ) << " WARNING: The available_thread_idx_ channel has been closed. "
155
158
" Fail to send thread_idx." ;
156
159
}
157
- VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' terminates." ;
160
+ VLOG (5 ) << " The prefetch thread of file idx '" << reader_idx
161
+ << " ' terminates." ;
158
162
}
159
163
160
164
class OpenFilesOp : public framework ::OperatorBase {
0 commit comments