Skip to content

Commit a2981f5

Browse files
committed
fix a bug
1 parent 87ac675 commit a2981f5

File tree

1 file changed

+50
-29
lines changed

1 file changed

+50
-29
lines changed

paddle/fluid/operators/reader/open_files_op.cc

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,31 @@ namespace reader {
2121

2222
class MultipleReader : public framework::ReaderBase {
2323
public:
24-
struct Quota {};
25-
2624
MultipleReader(const std::vector<std::string>& file_names,
2725
const std::vector<framework::DDim>& dims, size_t thread_num)
28-
: file_names_(file_names), dims_(dims), thread_num_(thread_num) {
29-
PADDLE_ENFORCE_GT(thread_num_, 0);
26+
: file_names_(file_names), dims_(dims) {
27+
prefetchers_.resize(thread_num);
3028
StartNewScheduler();
3129
}
3230

3331
void ReadNext(std::vector<framework::LoDTensor>* out) override;
3432
bool HasNext() const override;
3533
void ReInit() override;
3634

35+
~MultipleReader() { EndScheduler(); }
36+
3737
private:
3838
void StartNewScheduler();
39+
void EndScheduler();
3940
void ScheduleThreadFunc();
40-
void PrefetchThreadFunc(std::string file_name);
41+
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
4142

4243
std::vector<std::string> file_names_;
4344
std::vector<framework::DDim> dims_;
44-
size_t thread_num_;
45+
std::thread scheduler_;
46+
std::vector<std::thread> prefetchers_;
4547
framework::Channel<size_t>* waiting_file_idx_;
46-
framework::Channel<Quota>* thread_quotas_;
48+
framework::Channel<size_t>* available_thread_idx_;
4749
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
4850
mutable std::vector<framework::LoDTensor> local_buffer_;
4951
};
@@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const {
6567
}
6668

6769
void MultipleReader::ReInit() {
68-
buffer_->Close();
69-
thread_quotas_->Close();
70-
waiting_file_idx_->Close();
70+
EndScheduler();
7171
local_buffer_.clear();
72-
7372
StartNewScheduler();
7473
}
7574

7675
void MultipleReader::StartNewScheduler() {
76+
size_t thread_num = prefetchers_.size();
7777
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
78-
thread_quotas_ = framework::MakeChannel<Quota>(thread_num_);
78+
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
7979
buffer_ =
80-
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_);
80+
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
8181

8282
for (size_t i = 0; i < file_names_.size(); ++i) {
8383
waiting_file_idx_->Send(&i);
8484
}
8585
waiting_file_idx_->Close();
86-
for (size_t i = 0; i < thread_num_; ++i) {
87-
Quota quota;
88-
thread_quotas_->Send(&quota);
86+
for (size_t i = 0; i < thread_num; ++i) {
87+
available_thread_idx_->Send(&i);
8988
}
9089

91-
std::thread scheduler([this] { ScheduleThreadFunc(); });
92-
scheduler.detach();
90+
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
91+
}
92+
93+
void MultipleReader::EndScheduler() {
94+
available_thread_idx_->Close();
95+
buffer_->Close();
96+
waiting_file_idx_->Close();
97+
scheduler_.join();
98+
delete buffer_;
99+
delete available_thread_idx_;
100+
delete waiting_file_idx_;
93101
}
94102

95103
void MultipleReader::ScheduleThreadFunc() {
96104
VLOG(5) << "MultipleReader schedule thread starts.";
97105
size_t completed_thread_num = 0;
98-
Quota quota;
99-
while (thread_quotas_->Receive(&quota)) {
106+
size_t thread_idx;
107+
while (available_thread_idx_->Receive(&thread_idx)) {
108+
std::thread& prefetcher = prefetchers_[thread_idx];
109+
if (prefetcher.joinable()) {
110+
prefetcher.join();
111+
}
100112
size_t file_idx;
101113
if (waiting_file_idx_->Receive(&file_idx)) {
102114
// Still have files to read. Start a new prefetch thread.
103115
std::string file_name = file_names_[file_idx];
104-
std::thread prefetcher(
105-
[this, file_name] { PrefetchThreadFunc(file_name); });
106-
prefetcher.detach();
116+
prefetcher = std::thread([this, file_name, thread_idx] {
117+
PrefetchThreadFunc(file_name, thread_idx);
118+
});
107119
} else {
108120
// No more files to read.
109121
++completed_thread_num;
110-
if (completed_thread_num == thread_num_) {
111-
thread_quotas_->Close();
112-
buffer_->Close();
122+
if (completed_thread_num == prefetchers_.size()) {
113123
break;
114124
}
115125
}
116126
}
127+
// If users invoke ReInit() when scheduler is running, it will close the
128+
// 'available_thread_idx_' and prefetcher threads have no way to tell scheduler
129+
// to release their resource. So a check is needed before scheduler ends.
130+
for (auto& p : prefetchers_) {
131+
if (p.joinable()) {
132+
p.join();
133+
}
134+
}
117135
VLOG(5) << "MultipleReader schedule thread terminates.";
118136
}
119137

120-
void MultipleReader::PrefetchThreadFunc(std::string file_name) {
138+
void MultipleReader::PrefetchThreadFunc(std::string file_name,
139+
size_t thread_idx) {
121140
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
122141
std::unique_ptr<framework::ReaderBase> reader =
123142
CreateReaderByFileName(file_name, dims_);
@@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) {
131150
break;
132151
}
133152
}
134-
Quota quota;
135-
thread_quotas_->Send(&quota);
153+
if (!available_thread_idx_->Send(&thread_idx)) {
154+
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
155+
"Fail to send thread_idx.";
156+
}
136157
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
137158
}
138159

0 commit comments

Comments
 (0)