Skip to content

Commit aa3f505

Browse files
authored
Merge pull request #8841 from JiayiFeng/dev_double_buffer_for_cpp_reader
Basic double buffer for cpp reader
2 parents b341bac + 35e1e0d commit aa3f505

File tree

9 files changed

+175
-42
lines changed

9 files changed

+175
-42
lines changed

doc/design/cpp_data_feeding.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ class ReaderBase {
2020
PADDLE_ENFORCE(!shapes_.empty());
2121
}
2222
// Read the next batch of data. (A 'batch' can be only one instance)
23+
// If the next batch doesn't exist, the '*out' will be an empty std::vector.
2324
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
24-
// Show whether the next bacth exists.
25-
virtual bool HasNext() const = 0;
2625

2726
// Reinitialize the reader and read the file from the begin.
2827
virtual void ReInit() = 0;

paddle/fluid/framework/reader.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ class ReaderBase {
2626
PADDLE_ENFORCE(!shapes_.empty());
2727
}
2828
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
29-
virtual bool HasNext() const = 0;
3029

3130
virtual void ReInit() = 0;
3231

@@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase {
5251
PADDLE_ENFORCE_NOT_NULL(reader_);
5352
}
5453

55-
bool HasNext() const override { return reader_->HasNext(); }
56-
5754
void ReInit() override { reader_->ReInit(); }
5855

5956
protected:
@@ -69,7 +66,6 @@ class ReaderHolder {
6966
ReaderBase* Get() const { return reader_.get(); }
7067

7168
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
72-
bool HasNext() const { return reader_->HasNext(); }
7369
void ReInit() { reader_->ReInit(); }
7470

7571
DDim shape(size_t idx) const { return reader_->shape(idx); }

paddle/fluid/operators/read_op.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase {
6060
const platform::Place& dev_place) const override {
6161
framework::ReaderHolder* reader =
6262
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
63-
if (!reader->HasNext()) {
63+
std::vector<std::string> out_arg_names = Outputs("Out");
64+
std::vector<framework::LoDTensor> ins;
65+
reader->ReadNext(&ins);
66+
if (ins.empty()) {
6467
reader->ReInit();
68+
reader->ReadNext(&ins);
6569
PADDLE_ENFORCE(
66-
reader->HasNext(),
70+
!ins.empty(),
6771
"Reader can not read the next data even it has been re-initialized.");
6872
}
69-
std::vector<std::string> out_arg_names = Outputs("Out");
70-
std::vector<framework::LoDTensor> ins;
71-
reader->ReadNext(&ins);
7273
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
7374
for (size_t i = 0; i < ins.size(); ++i) {
7475
auto* out =

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_regist
22
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
33
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
44
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
5-
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE)
5+
op_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS reader_op_registry)
6+
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op create_double_buffer_reader_op PARENT_SCOPE)

paddle/fluid/operators/reader/create_batch_reader_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
6868
buffer_.clear();
6969
buffer_.reserve(batch_size_);
7070
for (int i = 0; i < batch_size_; ++i) {
71-
if (reader_->HasNext()) {
72-
buffer_.push_back(std::vector<framework::LoDTensor>());
73-
reader_->ReadNext(&buffer_.back());
74-
} else {
71+
buffer_.push_back(std::vector<framework::LoDTensor>());
72+
reader_->ReadNext(&buffer_.back());
73+
if (buffer_.back().empty()) {
74+
buffer_.pop_back();
7575
break;
7676
}
7777
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <thread>
16+
#include "paddle/fluid/framework/channel.h"
17+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
namespace reader {
22+
23+
static constexpr size_t kDoubleBufferSize = 2;
24+
25+
class DoubleBufferReader : public framework::DecoratedReader {
26+
public:
27+
explicit DoubleBufferReader(ReaderBase* reader)
28+
: DecoratedReader(reader),
29+
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
30+
kDoubleBufferSize)) {
31+
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
32+
prefetch.detach();
33+
}
34+
35+
void ReadNext(std::vector<framework::LoDTensor>* out) override;
36+
void ReInit() override;
37+
38+
~DoubleBufferReader() { buffer_->Close(); }
39+
40+
private:
41+
void PrefetchThreadFunc();
42+
43+
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
44+
};
45+
46+
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
47+
public:
48+
using framework::OperatorBase::OperatorBase;
49+
50+
private:
51+
void RunImpl(const framework::Scope& scope,
52+
const platform::Place& dev_place) const override {
53+
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
54+
->Get<framework::ReaderHolder>();
55+
auto* out = scope.FindVar(Output("Out"))
56+
->template GetMutable<framework::ReaderHolder>();
57+
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
58+
}
59+
};
60+
61+
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
62+
public:
63+
CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
64+
: DecoratedReaderMakerBase(op_proto, op_checker) {
65+
AddComment(R"DOC(
66+
CreateDoubleBufferReader Operator
67+
68+
A double buffer reader takes another reader as its 'underlying reader'.
69+
It launches another thread to execute the 'underlying reader' asynchronously,
70+
which prevents reading process from blocking subsequent training.
71+
)DOC");
72+
}
73+
};
74+
75+
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
76+
out->clear();
77+
buffer_->Receive(out);
78+
}
79+
80+
void DoubleBufferReader::ReInit() {
81+
reader_->ReInit();
82+
buffer_->Close();
83+
// The existing prefetch thread will terminate for the buffer_ is closed.
84+
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
85+
kDoubleBufferSize);
86+
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
87+
prefetch.detach();
88+
}
89+
90+
void DoubleBufferReader::PrefetchThreadFunc() {
91+
VLOG(5) << "A new prefetch thread starts.";
92+
while (true) {
93+
std::vector<framework::LoDTensor> batch;
94+
reader_->ReadNext(&batch);
95+
if (batch.empty()) {
96+
// EOF
97+
buffer_->Close();
98+
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
99+
break;
100+
}
101+
if (!buffer_->Send(&batch)) {
102+
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
103+
"prefetch thread terminates.";
104+
break;
105+
}
106+
}
107+
}
108+
109+
} // namespace reader
110+
} // namespace operators
111+
} // namespace paddle
112+
113+
namespace ops = paddle::operators::reader;
114+
REGISTER_DECORATED_READER_OPERATOR(create_double_buffer_reader,
115+
ops::CreateDoubleBufferReaderOp,
116+
ops::CreateDoubleBufferReaderOpMaker);

paddle/fluid/operators/reader/create_random_data_generator_op.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader {
5050
}
5151
}
5252

53-
bool HasNext() const override { return true; }
54-
5553
void ReInit() override { return; }
5654

5755
private:

paddle/fluid/operators/reader/create_shuffle_reader_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
3939
buffer_.clear();
4040
buffer_.reserve(buffer_size_);
4141
for (int i = 0; i < buffer_size_; ++i) {
42-
if (reader_->HasNext()) {
43-
buffer_.push_back(std::vector<framework::LoDTensor>());
44-
reader_->ReadNext(&buffer_.back());
45-
} else {
42+
buffer_.push_back(std::vector<framework::LoDTensor>());
43+
reader_->ReadNext(&buffer_.back());
44+
if (buffer_.back().empty()) {
45+
buffer_.pop_back();
4646
break;
4747
}
4848
}

python/paddle/fluid/tests/test_cpp_reader.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,30 @@
1515
import paddle.v2 as paddle
1616
import paddle.fluid as fluid
1717
import numpy as np
18+
import sys
1819

19-
prog = fluid.framework.Program()
20-
block = prog.current_block()
20+
startup_prog = fluid.framework.Program()
21+
startup_block = startup_prog.current_block()
2122

22-
random_reader = block.create_var(
23+
random_reader = startup_block.create_var(
2324
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
2425
random_reader.desc.set_dtypes(
2526
[fluid.core.VarDesc.VarType.FP32, fluid.core.VarDesc.VarType.FP32])
27+
random_reader.persistable = True
28+
shuffle_reader = startup_block.create_var(
29+
type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
30+
shuffle_reader.persistable = True
31+
batch_reader = startup_block.create_var(
32+
type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
33+
batch_reader.persistable = True
34+
double_buffer = startup_block.create_var(
35+
type=fluid.core.VarDesc.VarType.READER, name="DoubleBuffer")
36+
double_buffer.persistable = True
37+
38+
main_prog = startup_prog.clone()
39+
main_block = main_prog.current_block()
2640

27-
create_random_data_generator_op = block.append_op(
41+
create_random_data_generator_op = startup_block.append_op(
2842
type="create_random_data_generator",
2943
outputs={"Out": random_reader},
3044
attrs={
@@ -34,37 +48,45 @@
3448
"max": 1.0,
3549
'lod_levels': [0, 0]
3650
})
37-
shuffle_reader = block.create_var(
38-
type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
3951

40-
create_shuffle_reader_op = block.append_op(
52+
create_shuffle_reader_op = startup_block.append_op(
4153
type="create_shuffle_reader",
4254
inputs={"UnderlyingReader": random_reader},
4355
outputs={"Out": shuffle_reader},
4456
attrs={"buffer_size": 7})
4557

46-
batch_reader = block.create_var(
47-
type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
48-
49-
create_batch_reader_op = block.append_op(
58+
create_batch_reader_op = startup_block.append_op(
5059
type="create_batch_reader",
5160
inputs={"UnderlyingReader": shuffle_reader},
5261
outputs={"Out": batch_reader},
5362
attrs={"batch_size": 10})
5463

55-
out1 = block.create_var(type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out1")
56-
out2 = block.create_var(type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out2")
64+
create_double_buffer_reader_op = startup_block.append_op(
65+
type="create_double_buffer_reader",
66+
inputs={"UnderlyingReader": batch_reader},
67+
outputs={"Out": double_buffer})
68+
69+
out1 = main_block.create_var(
70+
type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out1")
71+
out2 = main_block.create_var(
72+
type=fluid.core.VarDesc.VarType.LOD_TENSOR, name="Out2")
5773

58-
read_op = block.append_op(
59-
type="read", inputs={"Reader": batch_reader},
74+
main_block.var("DoubleBuffer").desc.set_shapes(double_buffer.desc.shapes())
75+
main_block.var("DoubleBuffer").desc.set_dtypes(double_buffer.desc.dtypes())
76+
main_block.var("DoubleBuffer").desc.set_lod_levels(
77+
double_buffer.desc.lod_levels())
78+
79+
read_op = main_block.append_op(
80+
type="read",
81+
inputs={"Reader": double_buffer},
6082
outputs={"Out": [out1, out2]})
6183

6284
place = fluid.CPUPlace()
6385
exe = fluid.Executor(place)
6486

65-
[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
66-
67-
if not (res1.shape == (10, 2) and res2.shape == (10, 1)):
68-
exit(1)
87+
exe.run(startup_prog)
6988

70-
exit(0)
89+
for i in range(1, 100):
90+
[res1, res2] = exe.run(main_prog, fetch_list=[out1, out2])
91+
if not (res1.shape == (10, 2) and res2.shape == (10, 1)):
92+
exit(1)

0 commit comments

Comments
 (0)