@@ -21,29 +21,31 @@ namespace reader {
21
21
22
22
class MultipleReader : public framework ::ReaderBase {
23
23
public:
24
- struct Quota {};
25
-
26
24
MultipleReader (const std::vector<std::string>& file_names,
27
25
const std::vector<framework::DDim>& dims, size_t thread_num)
28
- : file_names_(file_names), dims_(dims), thread_num_(thread_num) {
29
- PADDLE_ENFORCE_GT (thread_num_, 0 );
26
+ : file_names_(file_names), dims_(dims) {
27
+ prefetchers_. resize (thread_num );
30
28
StartNewScheduler ();
31
29
}
32
30
33
31
void ReadNext (std::vector<framework::LoDTensor>* out) override ;
34
32
bool HasNext () const override ;
35
33
void ReInit () override ;
36
34
35
+ ~MultipleReader () { EndScheduler (); }
36
+
37
37
private:
38
38
void StartNewScheduler ();
39
+ void EndScheduler ();
39
40
void ScheduleThreadFunc ();
40
- void PrefetchThreadFunc (std::string file_name);
41
+ void PrefetchThreadFunc (std::string file_name, size_t thread_idx );
41
42
42
43
std::vector<std::string> file_names_;
43
44
std::vector<framework::DDim> dims_;
44
- size_t thread_num_;
45
+ std::thread scheduler_;
46
+ std::vector<std::thread> prefetchers_;
45
47
framework::Channel<size_t >* waiting_file_idx_;
46
- framework::Channel<Quota >* thread_quotas_ ;
48
+ framework::Channel<size_t >* available_thread_idx_ ;
47
49
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
48
50
mutable std::vector<framework::LoDTensor> local_buffer_;
49
51
};
@@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const {
65
67
}
66
68
67
69
void MultipleReader::ReInit () {
68
- buffer_->Close ();
69
- thread_quotas_->Close ();
70
- waiting_file_idx_->Close ();
70
+ EndScheduler ();
71
71
local_buffer_.clear ();
72
-
73
72
StartNewScheduler ();
74
73
}
75
74
76
75
void MultipleReader::StartNewScheduler () {
76
+ size_t thread_num = prefetchers_.size ();
77
77
waiting_file_idx_ = framework::MakeChannel<size_t >(file_names_.size ());
78
- thread_quotas_ = framework::MakeChannel<Quota>(thread_num_ );
78
+ available_thread_idx_ = framework::MakeChannel<size_t >(thread_num );
79
79
buffer_ =
80
- framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_ );
80
+ framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num );
81
81
82
82
for (size_t i = 0 ; i < file_names_.size (); ++i) {
83
83
waiting_file_idx_->Send (&i);
84
84
}
85
85
waiting_file_idx_->Close ();
86
- for (size_t i = 0 ; i < thread_num_; ++i) {
87
- Quota quota;
88
- thread_quotas_->Send ("a);
86
+ for (size_t i = 0 ; i < thread_num; ++i) {
87
+ available_thread_idx_->Send (&i);
89
88
}
90
89
91
- std::thread scheduler ([this ] { ScheduleThreadFunc (); });
92
- scheduler.detach ();
90
+ scheduler_ = std::thread ([this ] { ScheduleThreadFunc (); });
91
+ }
92
+
93
+ void MultipleReader::EndScheduler () {
94
+ available_thread_idx_->Close ();
95
+ buffer_->Close ();
96
+ waiting_file_idx_->Close ();
97
+ scheduler_.join ();
98
+ delete buffer_;
99
+ delete available_thread_idx_;
100
+ delete waiting_file_idx_;
93
101
}
94
102
95
103
void MultipleReader::ScheduleThreadFunc () {
96
104
VLOG (5 ) << " MultipleReader schedule thread starts." ;
97
105
size_t completed_thread_num = 0 ;
98
- Quota quota;
99
- while (thread_quotas_->Receive ("a)) {
106
+ size_t thread_idx;
107
+ while (available_thread_idx_->Receive (&thread_idx)) {
108
+ std::thread& prefetcher = prefetchers_[thread_idx];
109
+ if (prefetcher.joinable ()) {
110
+ prefetcher.join ();
111
+ }
100
112
size_t file_idx;
101
113
if (waiting_file_idx_->Receive (&file_idx)) {
102
114
// Still have files to read. Start a new prefetch thread.
103
115
std::string file_name = file_names_[file_idx];
104
- std::thread prefetcher (
105
- [ this , file_name] { PrefetchThreadFunc (file_name); } );
106
- prefetcher. detach ( );
116
+ prefetcher = std::thread ([ this , file_name, thread_idx] {
117
+ PrefetchThreadFunc (file_name, thread_idx );
118
+ } );
107
119
} else {
108
120
// No more file to read.
109
121
++completed_thread_num;
110
- if (completed_thread_num == thread_num_) {
111
- thread_quotas_->Close ();
112
- buffer_->Close ();
122
+ if (completed_thread_num == prefetchers_.size ()) {
113
123
break ;
114
124
}
115
125
}
116
126
}
127
+ // If users invoke ReInit() when scheduler is running, it will close the
128
+ // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
129
+ // to release their resource. So a check is needed before scheduler ends.
130
+ for (auto & p : prefetchers_) {
131
+ if (p.joinable ()) {
132
+ p.join ();
133
+ }
134
+ }
117
135
VLOG (5 ) << " MultipleReader schedule thread terminates." ;
118
136
}
119
137
120
- void MultipleReader::PrefetchThreadFunc (std::string file_name) {
138
+ void MultipleReader::PrefetchThreadFunc (std::string file_name,
139
+ size_t thread_idx) {
121
140
VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' starts." ;
122
141
std::unique_ptr<framework::ReaderBase> reader =
123
142
CreateReaderByFileName (file_name, dims_);
@@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) {
131
150
break ;
132
151
}
133
152
}
134
- Quota quota;
135
- thread_quotas_->Send ("a);
153
+ if (!available_thread_idx_->Send (&thread_idx)) {
154
+ VLOG (5 ) << " WARNING: The available_thread_idx_ channel has been closed. "
155
+ " Fail to send thread_idx." ;
156
+ }
136
157
VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' terminates." ;
137
158
}
138
159
0 commit comments