@@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
24
24
25
25
class DoubleBufferReader : public framework ::DecoratedReader {
26
26
public:
27
+ struct Item {
28
+ Item () : ctx_(nullptr ) {}
29
+
30
+ std::vector<framework::LoDTensor> payloads_;
31
+ platform::DeviceContext* ctx_;
32
+ };
33
+
27
34
explicit DoubleBufferReader (
28
35
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
29
36
: DecoratedReader(reader), place_(target_place) {
37
+ for (size_t i = 0 ; i < kDoubleBufferSize ; ++i) {
38
+ if (platform::is_gpu_place (place_)) {
39
+ #ifdef PADDLE_WITH_CUDA
40
+ ctxs_.emplace_back (new platform::CUDADeviceContext (
41
+ boost::get<platform::CUDAPlace>(place_)));
42
+ #else
43
+ #endif
44
+ }
45
+ }
46
+
30
47
start_thread ();
31
48
}
32
49
33
50
void start_thread () {
34
- buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
35
- kDoubleBufferSize );
51
+ buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize );
36
52
std::thread prefetch ([this ] { PrefetchThreadFunc (); });
37
53
prefetch.detach ();
38
54
}
@@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
47
63
private:
48
64
void PrefetchThreadFunc ();
49
65
50
- framework::Channel<std::vector<framework::LoDTensor> >* buffer_;
66
+ framework::Channel<Item >* buffer_;
51
67
platform::Place place_;
52
- mutable std::vector<framework::LoDTensor> local_buffer_;
68
+ std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
69
+ mutable Item local_buffer_;
53
70
};
54
71
55
72
class CreateDoubleBufferReaderOp : public framework ::OperatorBase {
@@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
104
121
};
105
122
106
123
void DoubleBufferReader::ReadNext (std::vector<framework::LoDTensor>* out) {
107
- out->clear ();
108
- if (local_buffer_.empty ()) {
109
- buffer_->Receive (out);
110
- } else {
111
- *out = local_buffer_;
112
- local_buffer_.clear ();
124
+ if (local_buffer_.payloads_ .empty ()) {
125
+ buffer_->Receive (&local_buffer_);
126
+ }
127
+
128
+ *out = local_buffer_.payloads_ ;
129
+ local_buffer_.payloads_ .clear ();
130
+ if (local_buffer_.ctx_ ) {
131
+ local_buffer_.ctx_ ->Wait ();
113
132
}
114
133
}
115
134
@@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() {
121
140
122
141
void DoubleBufferReader::PrefetchThreadFunc () {
123
142
VLOG (5 ) << " A new prefetch thread starts." ;
143
+ size_t gpu_ctx_offset = 0 ;
124
144
while (reader_->HasNext ()) {
125
- std::vector<framework::LoDTensor> batch;
126
- reader_->ReadNext (&batch);
145
+ Item batch;
146
+ reader_->ReadNext (&batch. payloads_ );
127
147
if (platform::is_gpu_place (place_)) {
128
148
std::vector<framework::LoDTensor> gpu_batch;
129
- gpu_batch.resize (batch.size ());
130
- for (size_t i = 0 ; i < batch.size (); ++i) {
131
- framework::TensorCopy (batch[i], place_, &gpu_batch[i]);
132
- gpu_batch[i].set_lod (batch[i].lod ());
149
+ auto & gpu_ctx = this ->ctxs_ [gpu_ctx_offset++];
150
+ gpu_ctx_offset %= this ->ctxs_ .size ();
151
+ gpu_batch.resize (batch.payloads_ .size ());
152
+ for (size_t i = 0 ; i < batch.payloads_ .size (); ++i) {
153
+ framework::TensorCopy (batch.payloads_ [i], place_, *gpu_ctx,
154
+ &gpu_batch[i]);
155
+ gpu_batch[i].set_lod (batch.payloads_ [i].lod ());
133
156
}
157
+ batch.ctx_ = gpu_ctx.get ();
158
+ std::swap (gpu_batch, batch.payloads_ );
134
159
}
135
160
136
161
if (!buffer_->Send (&batch)) {
@@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
143
168
}
144
169
145
170
bool DoubleBufferReader::HasNext () const {
146
- if (local_buffer_.empty ()) {
171
+ if (local_buffer_.payloads_ . empty ()) {
147
172
bool ok = buffer_->Receive (&local_buffer_);
148
173
return ok;
149
174
} else {
0 commit comments