@@ -30,8 +30,11 @@ namespace paddle {
30
30
namespace operators {
31
31
namespace reader {
32
32
33
+ enum ReaderThreadStatus { Running, Stopped };
34
+
33
35
void ReadThread (const std::vector<std::string>& file_list,
34
36
const std::vector<std::string>& slots, int batch_size,
37
+ int thread_id, std::vector<ReaderThreadStatus>* thread_status,
35
38
std::shared_ptr<LoDTensorBlockingQueue> queue);
36
39
37
40
class CTRReader : public framework ::FileReader {
@@ -40,13 +43,16 @@ class CTRReader : public framework::FileReader {
40
43
int batch_size, int thread_num,
41
44
const std::vector<std::string>& slots,
42
45
const std::vector<std::string>& file_list)
43
- : thread_num_(thread_num),
44
- batch_size_(batch_size),
45
- slots_(slots),
46
- file_list_(file_list) {
46
+ : batch_size_(batch_size), slots_(slots), file_list_(file_list) {
47
47
PADDLE_ENFORCE (queue != nullptr , " LoDTensorBlockingQueue must not be null" );
48
+ PADDLE_ENFORCE_GT (file_list.size (), 0 , " file list should not be empty" );
49
+ thread_num_ =
50
+ file_list_.size () > thread_num_ ? thread_num_ : file_list_.size ();
48
51
queue_ = queue;
49
52
SplitFiles ();
53
+ for (int i = 0 ; i < thread_num; ++i) {
54
+ read_thread_status_.push_back (Stopped);
55
+ }
50
56
}
51
57
52
58
~CTRReader () { queue_->Close (); }
@@ -69,28 +75,29 @@ class CTRReader : public framework::FileReader {
69
75
void Start () override {
70
76
VLOG (3 ) << " Start reader" ;
71
77
queue_->ReOpen ();
72
- for (int i = 0 ; i < file_groups_.size (); i++) {
73
- read_threads_.emplace_back (new std::thread (std::bind (
74
- &ReadThread, file_groups_[i], slots_, batch_size_, queue_)));
78
+ for (int thread_id = 0 ; thread_id < file_groups_.size (); thread_id++) {
79
+ read_threads_.emplace_back (new std::thread (
80
+ std::bind (&ReadThread, file_groups_[thread_id], slots_, batch_size_,
81
+ thread_id, &read_thread_status_, queue_)));
75
82
}
76
83
}
77
84
78
85
private:
79
86
void SplitFiles () {
80
- file_groups_.resize (file_list_.size () > thread_num_ ? thread_num_
81
- : file_list_.size ());
87
+ file_groups_.resize (thread_num_);
82
88
for (int i = 0 ; i < file_list_.size (); ++i) {
83
89
file_groups_[i % thread_num_].push_back (file_list_[i]);
84
90
}
85
91
}
86
92
87
93
private:
88
- const int thread_num_;
94
+ int thread_num_;
89
95
const int batch_size_;
90
96
const std::vector<std::string> slots_;
91
97
const std::vector<std::string> file_list_;
92
98
std::shared_ptr<LoDTensorBlockingQueue> queue_;
93
99
std::vector<std::unique_ptr<std::thread>> read_threads_;
100
+ std::vector<ReaderThreadStatus> read_thread_status_;
94
101
std::vector<std::vector<std::string>> file_groups_;
95
102
};
96
103
0 commit comments