@@ -21,6 +21,22 @@ namespace reader {
21
21
22
22
class MultipleReader : public framework ::ReaderBase {
23
23
public:
24
+ class ThreadBufferMap {
25
+ public:
26
+ std::vector<framework::LoDTensor>& operator [](
27
+ const std::thread::id& thread_id) {
28
+ std::lock_guard<std::mutex> lock (mutex_);
29
+ return buffer_[thread_id];
30
+ }
31
+
32
+ void Clear () { buffer_.clear (); }
33
+
34
+ private:
35
+ std::mutex mutex_;
36
+ std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
37
+ buffer_;
38
+ };
39
+
24
40
MultipleReader (const std::vector<std::string>& file_names,
25
41
const std::vector<framework::DDim>& dims, size_t thread_num)
26
42
: file_names_(file_names), dims_(dims) {
@@ -47,28 +63,27 @@ class MultipleReader : public framework::ReaderBase {
47
63
framework::Channel<size_t >* waiting_file_idx_;
48
64
framework::Channel<size_t >* available_thread_idx_;
49
65
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
50
- mutable std::vector<framework::LoDTensor> local_buffer_ ;
66
+ mutable ThreadBufferMap thread_buffer_map_ ;
51
67
};
52
68
53
69
void MultipleReader::ReadNext (std::vector<framework::LoDTensor>* out) {
54
70
if (!HasNext ()) {
55
71
PADDLE_THROW (" There is no next data!" );
56
72
}
57
-
58
- if (local_buffer_.empty ()) {
59
- buffer_->Receive (&local_buffer_);
60
- }
61
- *out = local_buffer_;
62
- local_buffer_.clear ();
73
+ auto & thread_local_buffer = thread_buffer_map_[std::this_thread::get_id ()];
74
+ *out = thread_local_buffer;
75
+ thread_local_buffer.clear ();
63
76
}
64
77
65
78
bool MultipleReader::HasNext () const {
66
- return local_buffer_.empty () ? buffer_->Receive (&local_buffer_) : true ;
79
+ auto & thread_local_buffer = thread_buffer_map_[std::this_thread::get_id ()];
80
+ return thread_local_buffer.empty () ? buffer_->Receive (&thread_local_buffer)
81
+ : true ;
67
82
}
68
83
69
84
void MultipleReader::ReInit () {
70
85
EndScheduler ();
71
- local_buffer_. clear ();
86
+ thread_buffer_map_. Clear ();
72
87
StartNewScheduler ();
73
88
}
74
89
@@ -176,7 +191,7 @@ class OpenFilesOp : public framework::OperatorBase {
176
191
const auto & ranks = Attr<std::vector<int >>(" ranks" );
177
192
PADDLE_ENFORCE (!shape_concat.empty () && !ranks.empty ());
178
193
PADDLE_ENFORCE_EQ (std::accumulate (ranks.begin (), ranks.end (), 0 ),
179
- int (shape_concat.size ()),
194
+ static_cast < int > (shape_concat.size ()),
180
195
" The accumulate of all ranks should be equal to the "
181
196
" shape concat's length." );
182
197
const auto & file_names = Attr<std::vector<std::string>>(" file_names" );
0 commit comments