@@ -30,19 +30,23 @@ namespace paddle {
30
30
namespace operators {
31
31
namespace reader {
32
32
33
// Worker routine run on each reader thread: CTRReader::Start() launches one
// std::thread per file group and passes it that group's files, the slot list,
// the batch size and the shared queue.
// Only the declaration is visible in this excerpt.
+ void ReadThread (const std::vector<std::string>& file_list,
34
+ const std::vector<std::string>& slots, int batch_size,
35
+ std::shared_ptr<LoDTensorBlockingQueue> queue);
36
+
33
37
class CTRReader : public framework ::FileReader {
34
38
public:
35
39
explicit CTRReader (const std::shared_ptr<LoDTensorBlockingQueue>& queue,
36
40
int batch_size, int thread_num,
37
41
const std::vector<std::string>& slots,
38
42
const std::vector<std::string>& file_list)
39
- : framework::FileReader() {
40
- thread_num_ = thread_num;
41
- batch_size_ = batch_size;
43
+ : thread_num_(thread_num),
44
+ batch_size_(batch_size),
45
+ slots_(slots),
46
+ file_list_(file_list) {
42
47
PADDLE_ENFORCE (queue != nullptr , " LoDTensorBlockingQueue must not be null" );
43
48
queue_ = queue;
44
- slots_ = slots;
45
- file_list_ = file_list;
49
+ SplitFiles ();
46
50
}
47
51
48
52
// Closes the queue, then joins any reader threads that are still joinable.
// Destroying a joinable std::thread calls std::terminate(), so threads that
// were started by Start() but never reaped by Shutdown() must be joined here
// before the read_threads_ member is destroyed.
~CTRReader() {
  queue_->Close();
  for (auto& read_thread : read_threads_) {
    if (read_thread->joinable()) {
      read_thread->join();
    }
  }
  read_threads_.clear();
}
@@ -53,30 +57,41 @@ class CTRReader : public framework::FileReader {
53
57
if (!success) out->clear ();
54
58
}
55
59
56
- void Shutdown () override { queue_->Close (); }
60
+ void Shutdown () override {
61
+ VLOG (3 ) << " Shutdown reader" ;
62
+ for (auto & read_thread : read_threads_) {
63
+ read_thread->join ();
64
+ }
65
+ read_threads_.clear ();
66
+ queue_->Close ();
67
+ }
57
68
58
69
void Start () override {
70
+ VLOG (3 ) << " Start reader" ;
59
71
queue_->ReOpen ();
60
- // for (int i = 0; i < thread_num_; i++) {
61
- // read_threads_.emplace_back(
62
- // new std::thread(std::bind(&CTRReader::ReadThread, this,
63
- // file_list_,
64
- // slots_, batch_size_, queue_)));
65
- // }
72
+ for (int i = 0 ; i < file_groups_.size (); i++) {
73
+ read_threads_.emplace_back (new std::thread (std::bind (
74
+ &ReadThread, file_groups_[i], slots_, batch_size_, queue_)));
75
+ }
66
76
}
67
77
68
78
private:
69
- void ReadThread (const std::vector<std::string>& file_list,
70
- const std::vector<std::string>& slots, int batch_size,
71
- std::shared_ptr<LoDTensorBlockingQueue> queue);
79
+ void SplitFiles () {
80
+ file_groups_.resize (file_list_.size () > thread_num_ ? thread_num_
81
+ : file_list_.size ());
82
+ for (int i = 0 ; i < file_list_.size (); ++i) {
83
+ file_groups_[i % thread_num_].push_back (file_list_[i]);
84
+ }
85
+ }
72
86
73
87
private:
88
// Settings copied from the constructor arguments; const, so fixed for the
// lifetime of the reader.
// thread_num_: caps the number of file groups built by SplitFiles().
+ const int thread_num_;
89
// batch_size_: passed unchanged to every ReadThread.
+ const int batch_size_;
90
// slots_: passed unchanged to every ReadThread.
+ const std::vector<std::string> slots_;
91
// file_list_: all input files, before they are split into groups.
+ const std::vector<std::string> file_list_;
74
92
// Queue shared with the reader threads; checked non-null in the constructor.
std::shared_ptr<LoDTensorBlockingQueue> queue_;
75
93
// Threads created by Start(); joined and cleared by Shutdown().
std::vector<std::unique_ptr<std::thread>> read_threads_;
76
- int thread_num_;
77
- int batch_size_;
78
- std::vector<std::string> slots_;
79
- std::vector<std::string> file_list_;
94
// Per-thread partitions of file_list_, filled by SplitFiles(); Start()
// launches one thread per entry.
+ std::vector<std::vector<std::string>> file_groups_;
80
95
};
81
96
82
97
} // namespace reader
0 commit comments