|
18 | 18 | namespace paddle {
|
19 | 19 | namespace operators {
|
20 | 20 | namespace reader {
|
21 |
| -BufferedReader::~BufferedReader() { |
22 |
| - reader_->Shutdown(); |
23 |
| - buffer_.clear(); |
24 |
| -} |
| 21 | +BufferedReader::~BufferedReader() { reader_->Shutdown(); } |
25 | 22 | BufferedReader::BufferedReader(
|
26 | 23 | const std::shared_ptr<framework::ReaderBase> &reader,
|
27 | 24 | const platform::Place &place, size_t buffer_size)
|
28 | 25 | : framework::DecoratedReader(reader),
|
29 | 26 | thread_pool_(1),
|
30 | 27 | place_(place),
|
31 | 28 | buffer_size_(buffer_size) {
|
| 29 | + cpu_buffer_.resize(buffer_size); |
| 30 | + gpu_buffer_.resize(buffer_size); |
32 | 31 | AppendFutureToBatchSize();
|
33 | 32 | }
|
34 | 33 | void BufferedReader::AppendFutureToBatchSize() {
|
35 |
| - while (buffer_.size() < buffer_size_) { |
36 |
| - AppendFuture(); |
| 34 | + PADDLE_ENFORCE_EQ(position_.size(), 0U); |
| 35 | + for (size_t i = 0; i < buffer_size_; ++i) { |
| 36 | + AppendFuture(i); |
37 | 37 | }
|
38 | 38 | }
|
39 |
| -void BufferedReader::AppendFuture() { |
40 |
| - buffer_.emplace_back(thread_pool_.enqueue([this] { |
41 |
| - TensorVec cpu_buffer; |
42 |
| - reader_->ReadNext(&cpu_buffer); |
43 |
| - if (platform::is_gpu_place(place_)) { |
44 |
| - TensorVec gpu_buffer; |
| 39 | +void BufferedReader::AppendFuture(size_t i) { |
| 40 | + position_.emplace(thread_pool_.enqueue([this, i]() -> size_t { |
| 41 | + TensorVec &cpu = cpu_buffer_[i]; |
| 42 | + reader_->ReadNext(&cpu); |
45 | 43 |
|
46 |
| - for (size_t i = 0; i < cpu_buffer.size(); ++i) { |
47 |
| - gpu_buffer.emplace_back(); |
48 |
| - framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back()); |
49 |
| - } |
| 44 | + if (cpu.empty()) { |
| 45 | + return -1UL; |
| 46 | + } |
50 | 47 |
|
51 |
| - cpu_buffer = gpu_buffer; |
| 48 | + if (platform::is_gpu_place(place_)) { |
| 49 | + TensorVec &gpu = gpu_buffer_[i]; |
| 50 | + gpu.resize(cpu.size()); |
| 51 | + for (size_t i = 0; i < cpu.size(); ++i) { |
| 52 | + framework::TensorCopySync(cpu[i], place_, &gpu[i]); |
| 53 | + } |
52 | 54 | }
|
53 |
| - return cpu_buffer; |
| 55 | + return i; |
54 | 56 | }));
|
55 | 57 | }
|
56 | 58 | void BufferedReader::ShutdownImpl() {
|
57 | 59 | reader_->Shutdown();
|
58 |
| - buffer_.clear(); |
| 60 | + while (!position_.empty()) { |
| 61 | + position_.pop(); |
| 62 | + } |
59 | 63 | }
|
60 | 64 | void BufferedReader::StartImpl() {
|
61 | 65 | reader_->Start();
|
62 | 66 | AppendFutureToBatchSize();
|
63 | 67 | }
|
64 | 68 | void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
|
65 |
| - PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_); |
66 |
| - *out = buffer_.front().get(); |
67 |
| - buffer_.pop_front(); |
68 |
| - AppendFuture(); |
| 69 | + if (position_.empty()) { |
| 70 | + out->clear(); |
| 71 | + return; |
| 72 | + } |
| 73 | + size_t i = position_.front().get(); |
| 74 | + position_.pop(); |
| 75 | + |
| 76 | + if (i == -1UL) { |
| 77 | + ReadNextImpl(out); |
| 78 | + return; |
| 79 | + } |
| 80 | + |
| 81 | + *out = platform::is_gpu_place(place_) ? gpu_buffer_[i] : cpu_buffer_[i]; |
| 82 | + AppendFuture(i); |
69 | 83 | }
|
70 | 84 |
|
71 | 85 | } // namespace reader
|
|
0 commit comments