1
+ // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include " paddle/fluid/framework/channel.h"
16
+ #include " paddle/fluid/operators/reader/reader_op_registry.h"
17
+
18
+ namespace paddle {
19
+ namespace operators {
20
+ namespace reader {
21
+
22
+ class MultipleReader : public framework ::ReaderBase {
23
+ public:
24
+ struct Quota {};
25
+
26
+ MultipleReader (const std::vector<std::string>& file_names,
27
+ const std::vector<framework::DDim>& dims, size_t thread_num)
28
+ : file_names_(file_names), dims_(dims), thread_num_(thread_num) {
29
+ PADDLE_ENFORCE_GT (thread_num_, 0 );
30
+ StartNewScheduler ();
31
+ }
32
+
33
+ void ReadNext (std::vector<framework::LoDTensor>* out) override ;
34
+ bool HasNext () const override ;
35
+ void ReInit () override ;
36
+
37
+ private:
38
+ void StartNewScheduler ();
39
+ void ScheduleThreadFunc ();
40
+ void PrefetchThreadFunc (std::string file_name);
41
+
42
+ std::vector<std::string> file_names_;
43
+ std::vector<framework::DDim> dims_;
44
+ size_t thread_num_;
45
+ framework::Channel<size_t >* waiting_file_idx_;
46
+ framework::Channel<Quota>* thread_quotas_;
47
+ framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
48
+ mutable std::vector<framework::LoDTensor> local_buffer_;
49
+ };
50
+
51
+ void MultipleReader::ReadNext (std::vector<framework::LoDTensor>* out) {
52
+ if (!HasNext ()) {
53
+ PADDLE_THROW (" There is no next data!" );
54
+ }
55
+
56
+ if (local_buffer_.empty ()) {
57
+ buffer_->Receive (&local_buffer_);
58
+ }
59
+ *out = local_buffer_;
60
+ local_buffer_.clear ();
61
+ }
62
+
63
+ bool MultipleReader::HasNext () const {
64
+ return local_buffer_.empty () ? buffer_->Receive (&local_buffer_) : true ;
65
+ }
66
+
67
+ void MultipleReader::ReInit () {
68
+ buffer_->Close ();
69
+ thread_quotas_->Close ();
70
+ waiting_file_idx_->Close ();
71
+ local_buffer_.clear ();
72
+
73
+ StartNewScheduler ();
74
+ }
75
+
76
+ void MultipleReader::StartNewScheduler () {
77
+ waiting_file_idx_ = framework::MakeChannel<size_t >(file_names_.size ());
78
+ thread_quotas_ = framework::MakeChannel<Quota>(thread_num_);
79
+ buffer_ =
80
+ framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_);
81
+
82
+ for (size_t i = 0 ; i < file_names_.size (); ++i) {
83
+ waiting_file_idx_->Send (&i);
84
+ }
85
+ waiting_file_idx_->Close ();
86
+ for (size_t i = 0 ; i < thread_num_; ++i) {
87
+ Quota quota;
88
+ thread_quotas_->Send ("a);
89
+ }
90
+
91
+ std::thread scheduler ([this ] { ScheduleThreadFunc (); });
92
+ scheduler.detach ();
93
+ }
94
+
95
+ void MultipleReader::ScheduleThreadFunc () {
96
+ VLOG (5 ) << " MultipleReader schedule thread starts." ;
97
+ size_t completed_thread_num = 0 ;
98
+ Quota quota;
99
+ while (thread_quotas_->Receive ("a)) {
100
+ size_t file_idx;
101
+ if (waiting_file_idx_->Receive (&file_idx)) {
102
+ // Still have files to read. Start a new prefetch thread.
103
+ std::string file_name = file_names_[file_idx];
104
+ std::thread prefetcher (
105
+ [this , file_name] { PrefetchThreadFunc (file_name); });
106
+ prefetcher.detach ();
107
+ } else {
108
+ // No more file to read.
109
+ ++completed_thread_num;
110
+ if (completed_thread_num == thread_num_) {
111
+ thread_quotas_->Close ();
112
+ buffer_->Close ();
113
+ break ;
114
+ }
115
+ }
116
+ }
117
+ VLOG (5 ) << " MultipleReader schedule thread terminates." ;
118
+ }
119
+
120
+ void MultipleReader::PrefetchThreadFunc (std::string file_name) {
121
+ VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' starts." ;
122
+ std::unique_ptr<framework::ReaderBase> reader =
123
+ CreateReaderByFileName (file_name, dims_);
124
+ while (reader->HasNext ()) {
125
+ std::vector<framework::LoDTensor> ins;
126
+ reader->ReadNext (&ins);
127
+ if (!buffer_->Send (&ins)) {
128
+ VLOG (5 ) << " WARNING: The buffer channel has been closed. The prefetch "
129
+ " thread of file '"
130
+ << file_name << " ' will terminate." ;
131
+ break ;
132
+ }
133
+ }
134
+ Quota quota;
135
+ thread_quotas_->Send ("a);
136
+ VLOG (5 ) << " The prefetch thread of file '" << file_name << " ' terminates." ;
137
+ }
138
+
139
+ class OpenFilesOp : public framework ::OperatorBase {
140
+ public:
141
+ using framework::OperatorBase::OperatorBase;
142
+
143
+ private:
144
+ void RunImpl (const framework::Scope& scope,
145
+ const platform::Place& dev_place) const override {
146
+ const auto & shape_concat = Attr<std::vector<int >>(" shape_concat" );
147
+ const auto & ranks = Attr<std::vector<int >>(" ranks" );
148
+ PADDLE_ENFORCE (!shape_concat.empty () && !ranks.empty ());
149
+ PADDLE_ENFORCE_EQ (std::accumulate (ranks.begin (), ranks.end (), 0 ),
150
+ int (shape_concat.size ()),
151
+ " The accumulate of all ranks should be equal to the "
152
+ " shape concat's length." );
153
+ const auto & file_names = Attr<std::vector<std::string>>(" file_names" );
154
+ PADDLE_ENFORCE (!file_names.empty (), " No file to be read!" );
155
+ const size_t thread_num = Attr<int >(" thread_num" );
156
+
157
+ auto * out = scope.FindVar (Output (" Out" ))
158
+ ->template GetMutable <framework::ReaderHolder>();
159
+ out->Reset (new MultipleReader (
160
+ file_names, RestoreShapes (shape_concat, ranks), thread_num));
161
+ }
162
+ };
163
+
164
+ class OpenFilesOpMaker : public framework ::OpProtoAndCheckerMaker {
165
+ public:
166
+ OpenFilesOpMaker (OpProto* op_proto, OpAttrChecker* op_checker)
167
+ : OpProtoAndCheckerMaker(op_proto, op_checker) {
168
+ AddComment (R"DOC(
169
+ OpenFiles Operator
170
+
171
+ An OpenFilesOp creates a MultipleReader, which is able to
172
+ read data multi-threaded from multiple files.
173
+ )DOC" );
174
+ AddOutput (" Out" , " (ReaderHolder) The created MultipleReader." );
175
+ AddAttr<std::vector<int >>(" shape_concat" ,
176
+ " The concat of all data's shapes." );
177
+ AddAttr<std::vector<int >>(
178
+ " ranks" ,
179
+ " The ranks of each data."
180
+ " e.g."
181
+ " shape_concat = [2,3,4,5,6]"
182
+ " ranks = [3,2]"
183
+ " It means the reader will generate two data each time,"
184
+ " whose shapes are [2,3,4] and [5,6] respectively." );
185
+ AddAttr<std::vector<int >>(" lod_levels" , " The LoD levels of each data." );
186
+ AddAttr<std::vector<std::string>>(" file_names" , " Files to be read." );
187
+ AddAttr<int >(" thread_num" , " The maximal concurrent prefetch thread number." )
188
+ .GreaterThan (0 );
189
+ }
190
+ };
191
+
192
+ } // namespace reader
193
+ } // namespace operators
194
+ } // namespace paddle
195
+
196
+ namespace reader = paddle::operators::reader;
197
+
198
+ REGISTER_FILE_READER_OPERATOR (open_files, reader::OpenFilesOp,
199
+ reader::OpenFilesOpMaker);
0 commit comments