Skip to content

Commit 4051fb3

Browse files
committed
add monitor thread
1 parent e677833 commit 4051fb3

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

paddle/fluid/operators/reader/ctr_reader.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ class MultiGzipReader : public Reader {
123123
size_t current_reader_index_ = 0;
124124
};
125125

126+
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
127+
std::shared_ptr<LoDTensorBlockingQueue> queue) {
128+
VLOG(3) << "monitor thread in";
129+
bool reader_thread_is_running = true;
130+
while (reader_thread_is_running) {
131+
VLOG(3) << "reader_thread_is_running";
132+
reader_thread_is_running = false;
133+
for (size_t i = 0; i < (*thread_status).size(); ++i) {
134+
if ((*thread_status)[i] == Running) {
135+
VLOG(3) << "reader is running!";
136+
reader_thread_is_running = true;
137+
}
138+
}
139+
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
140+
}
141+
VLOG(3) << "all reader thread is stopped, push empty data into queue";
142+
queue->Push({});
143+
VLOG(3) << "monitor thread exited";
144+
}
145+
126146
void ReadThread(const std::vector<std::string>& file_list,
127147
const std::vector<std::string>& slots, int batch_size,
128148
int thread_id, std::vector<ReaderThreadStatus>* thread_status,

paddle/fluid/operators/reader/ctr_reader.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <sys/time.h>
1818

19+
#include <chrono> // NOLINT
1920
#include <cstdlib>
2021
#include <fstream>
2122
#include <iostream>
@@ -39,6 +40,11 @@ void ReadThread(const std::vector<std::string>& file_list,
3940
int thread_id, std::vector<ReaderThreadStatus>* thread_status,
4041
std::shared_ptr<LoDTensorBlockingQueue> queue);
4142

43+
// monitor all running thread, if they are all stopped,
44+
// then push an empty data into LoDTensorBlockingQueue
45+
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
46+
std::shared_ptr<LoDTensorBlockingQueue> queue);
47+
4248
class CTRReader : public framework::FileReader {
4349
public:
4450
explicit CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
@@ -58,7 +64,7 @@ class CTRReader : public framework::FileReader {
5864
}
5965
}
6066

61-
~CTRReader() { Shutdown(); }
67+
~CTRReader() {}
6268

6369
void ReadNext(std::vector<framework::LoDTensor>* out) override {
6470
bool success;
@@ -68,12 +74,19 @@ class CTRReader : public framework::FileReader {
6874

6975
void Shutdown() override {
7076
VLOG(3) << "Shutdown reader";
77+
if (status_ == ReaderStatus::kStopped) {
78+
return;
79+
}
7180
// shutdown should stop all the reader thread
7281
for (auto& read_thread : read_threads_) {
7382
read_thread->join();
7483
}
84+
monitor_thread_->join();
85+
7586
read_threads_.clear();
87+
monitor_thread_.reset(nullptr);
7688
queue_->Close();
89+
status_ = ReaderStatus::kStopped;
7790
}
7891

7992
void Start() override {
@@ -87,6 +100,9 @@ class CTRReader : public framework::FileReader {
87100
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_,
88101
thread_id, &read_thread_status_, queue_)));
89102
}
103+
monitor_thread_.reset(new std::thread(
104+
std::bind(&MonitorThread, &read_thread_status_, queue_)));
105+
status_ = ReaderStatus::kRunning;
90106
}
91107

92108
private:
@@ -107,6 +123,7 @@ class CTRReader : public framework::FileReader {
107123
const std::vector<std::string> file_list_;
108124
std::shared_ptr<LoDTensorBlockingQueue> queue_;
109125
std::vector<std::unique_ptr<std::thread>> read_threads_;
126+
std::unique_ptr<std::thread> monitor_thread_;
110127
std::vector<ReaderThreadStatus> read_thread_status_;
111128
std::vector<std::vector<std::string>> file_groups_;
112129
};

paddle/fluid/operators/reader/ctr_reader_test.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ TEST(CTR_READER, read_data) {
107107
size_t batch_num =
108108
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
109109

110+
std::vector<LoDTensor> out;
110111
for (size_t i = 0; i < batch_num; ++i) {
111-
std::vector<LoDTensor> out;
112112
reader.ReadNext(&out);
113113
ASSERT_EQ(out.size(), slots.size() + 1);
114114
auto& label_tensor = out.back();
@@ -126,5 +126,12 @@ TEST(CTR_READER, read_data) {
126126
tensor_6002.dims()[1] * sizeof(int64_t)),
127127
0);
128128
}
129+
reader.ReadNext(&out);
130+
ASSERT_EQ(out.size(), 0);
129131
ASSERT_EQ(queue->Size(), 0);
132+
reader.Shutdown();
133+
134+
reader.Start();
135+
reader.Shutdown();
136+
ASSERT_EQ(queue->Size(), 5);
130137
}

0 commit comments

Comments
 (0)