Skip to content

Commit 48f213e

Browse files
authored
Merge pull request #8991 from reyoung/feature/shuffle_reader
Feature/shuffle reader
2 parents 881c522 + 127b371 commit 48f213e

File tree

12 files changed

+270
-127
lines changed

12 files changed

+270
-127
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
445445
}
446446

447447
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
448-
Variable* var = scope_.FindVar(name);
449-
if (var->IsType<ReaderHolder>()) {
450-
return var->Get<ReaderHolder>().shapes();
451-
} else {
452-
PADDLE_THROW(
453-
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
454-
"type_id is %s.",
455-
name, var->Type().name());
456-
}
448+
PADDLE_THROW("Only compile time support this method");
457449
}
458450

459451
void SetDim(const std::string& name, const DDim& dim) override {
@@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
470462

471463
void SetRepeatedDims(const std::string& name,
472464
const std::vector<DDim>& dims) override {
473-
Variable* var = scope_.FindVar(name);
474-
if (var->IsType<ReaderHolder>()) {
475-
var->GetMutable<ReaderHolder>()->set_shapes(dims);
476-
} else {
477-
PADDLE_THROW(
478-
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
479-
"type_id is %s.",
480-
name, var->Type().name());
481-
}
465+
PADDLE_THROW("Only compile time support this method");
482466
}
483467

484468
proto::VarType::Type GetVarType(const std::string& name) const override {

paddle/fluid/framework/reader.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,22 @@
1616

1717
namespace paddle {
1818
namespace framework {
19+
ReaderBase::~ReaderBase() {}
1920

20-
DDim ReaderBase::shape(size_t idx) const {
21-
PADDLE_ENFORCE_LT(
22-
idx, shapes_.size(),
23-
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
24-
shapes_.size());
25-
return shapes_[idx];
26-
}
21+
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
22+
23+
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
24+
ReadNextImpl(out);
25+
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
26+
for (size_t i = 0; i < dims_.size(); ++i) {
27+
auto &actual = out->at(i).dims();
28+
auto &expect = dims_[i];
2729

30+
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
31+
for (int j = 0; j < actual.size(); ++j) {
32+
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
33+
}
34+
}
35+
}
2836
} // namespace framework
2937
} // namespace paddle

paddle/fluid/framework/reader.h

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,29 @@
1616

1717
#include "paddle/fluid/framework/ddim.h"
1818
#include "paddle/fluid/framework/lod_tensor_array.h"
19+
#include "paddle/fluid/platform/place.h"
20+
21+
#include <memory>
22+
#include <thread>
23+
#include <vector>
1924

2025
namespace paddle {
2126
namespace framework {
2227

2328
class ReaderBase {
2429
public:
25-
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
26-
PADDLE_ENFORCE(!shapes_.empty());
27-
}
2830
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
2931

3032
virtual void ReInit() = 0;
3133

32-
DDim shape(size_t idx) const;
33-
std::vector<DDim> shapes() const { return shapes_; }
34-
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
35-
3634
virtual bool HasNext() const = 0;
3735

38-
virtual ~ReaderBase() {}
39-
40-
protected:
41-
std::vector<DDim> shapes_;
42-
};
43-
44-
class FileReader : public ReaderBase {
45-
public:
46-
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
36+
virtual ~ReaderBase();
4737
};
4838

4939
class DecoratedReader : public ReaderBase {
5040
public:
51-
explicit DecoratedReader(ReaderBase* reader)
52-
: ReaderBase(reader->shapes()), reader_(reader) {
41+
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
5342
PADDLE_ENFORCE_NOT_NULL(reader_);
5443
}
5544

@@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase {
6150
ReaderBase* reader_;
6251
};
6352

53+
class FileReader : public ReaderBase {
54+
public:
55+
explicit FileReader(const std::vector<DDim>& dims);
56+
57+
void ReadNext(std::vector<LoDTensor>* out) override;
58+
59+
protected:
60+
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
61+
62+
private:
63+
std::vector<DDim> dims_;
64+
};
65+
6466
// The ReaderHolder is used as reader' unified wrapper,
6567
// making it easier to access different type reader in Variables.
6668
class ReaderHolder {
@@ -78,19 +80,6 @@ class ReaderHolder {
7880
reader_->ReInit();
7981
}
8082

81-
DDim shape(size_t idx) const {
82-
PADDLE_ENFORCE_NOT_NULL(reader_);
83-
return reader_->shape(idx);
84-
}
85-
std::vector<DDim> shapes() const {
86-
PADDLE_ENFORCE_NOT_NULL(reader_);
87-
return reader_->shapes();
88-
}
89-
void set_shapes(const std::vector<DDim>& shapes) {
90-
PADDLE_ENFORCE_NOT_NULL(reader_);
91-
reader_->set_shapes(shapes);
92-
}
93-
9483
bool HasNext() const { return reader_->HasNext(); }
9584

9685
private:

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
2424

2525
class DoubleBufferReader : public framework::DecoratedReader {
2626
public:
27-
explicit DoubleBufferReader(ReaderBase* reader)
28-
: DecoratedReader(reader),
29-
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
30-
kDoubleBufferSize)) {
31-
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
27+
struct Item {
28+
Item() : ctx_(nullptr) {}
29+
30+
std::vector<framework::LoDTensor> payloads_;
31+
platform::DeviceContext* ctx_;
32+
};
33+
34+
explicit DoubleBufferReader(
35+
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
36+
: DecoratedReader(reader), place_(target_place) {
37+
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
38+
if (platform::is_gpu_place(place_)) {
39+
#ifdef PADDLE_WITH_CUDA
40+
ctxs_.emplace_back(new platform::CUDADeviceContext(
41+
boost::get<platform::CUDAPlace>(place_)));
42+
#endif
43+
}
44+
}
45+
46+
start_thread();
47+
}
48+
49+
void start_thread() {
50+
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
51+
std::thread prefetch([this] { PrefetchThreadFunc(); });
3252
prefetch.detach();
3353
}
3454

@@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
4262
private:
4363
void PrefetchThreadFunc();
4464

45-
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
65+
framework::Channel<Item>* buffer_;
66+
platform::Place place_;
67+
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
68+
mutable Item local_buffer_;
4669
};
4770

4871
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
@@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
5679
->Get<framework::ReaderHolder>();
5780
auto* out = scope.FindVar(Output("Out"))
5881
->template GetMutable<framework::ReaderHolder>();
59-
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
82+
83+
auto place_str = Attr<std::string>("place");
84+
platform::Place place;
85+
if (place_str == "CPU") {
86+
place = platform::CPUPlace();
87+
} else {
88+
std::istringstream sin(place_str);
89+
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
90+
size_t num;
91+
sin >> num;
92+
place = platform::CUDAPlace(static_cast<int>(num));
93+
}
94+
95+
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
6096
}
6197
};
6298

@@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
71107
It launches another thread to execute the 'underlying reader' asynchronously,
72108
which prevents reading process from blocking subsequent training.
73109
)DOC");
110+
std::unordered_set<std::string> enum_range;
111+
constexpr size_t kMaxCUDADevs = 128;
112+
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
113+
enum_range.insert(string::Sprintf("CUDA:%d", i));
114+
}
115+
enum_range.insert("CPU");
116+
AddAttr<std::string>("place", "The double buffer place, default is CPU")
117+
.SetDefault("CPU")
118+
.InEnum({enum_range});
74119
}
75120
};
76121

77122
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
78-
out->clear();
79-
buffer_->Receive(out);
123+
if (local_buffer_.payloads_.empty()) {
124+
buffer_->Receive(&local_buffer_);
125+
}
126+
127+
*out = local_buffer_.payloads_;
128+
local_buffer_.payloads_.clear();
129+
if (local_buffer_.ctx_) {
130+
local_buffer_.ctx_->Wait();
131+
}
80132
}
81133

82134
void DoubleBufferReader::ReInit() {
83135
reader_->ReInit();
84136
buffer_->Close();
85-
// The existing prefetch thread will terminate for the buffer_ is closed.
86-
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
87-
kDoubleBufferSize);
88-
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
89-
prefetch.detach();
137+
start_thread();
90138
}
91139

92140
void DoubleBufferReader::PrefetchThreadFunc() {
93141
VLOG(5) << "A new prefetch thread starts.";
94-
while (true) {
95-
std::vector<framework::LoDTensor> batch;
96-
reader_->ReadNext(&batch);
97-
if (batch.empty()) {
98-
// EOF
99-
buffer_->Close();
100-
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
101-
break;
142+
size_t gpu_ctx_offset = 0;
143+
while (reader_->HasNext()) {
144+
Item batch;
145+
reader_->ReadNext(&batch.payloads_);
146+
if (platform::is_gpu_place(place_)) {
147+
std::vector<framework::LoDTensor> gpu_batch;
148+
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
149+
gpu_ctx_offset %= this->ctxs_.size();
150+
gpu_batch.resize(batch.payloads_.size());
151+
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
152+
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
153+
&gpu_batch[i]);
154+
gpu_batch[i].set_lod(batch.payloads_[i].lod());
155+
}
156+
batch.ctx_ = gpu_ctx.get();
157+
std::swap(gpu_batch, batch.payloads_);
102158
}
159+
103160
if (!buffer_->Send(&batch)) {
104161
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
105162
"prefetch thread terminates.";
106163
break;
107164
}
108165
}
166+
buffer_->Close();
109167
}
110168

111-
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
169+
bool DoubleBufferReader::HasNext() const {
170+
if (local_buffer_.payloads_.empty()) {
171+
bool ok = buffer_->Receive(&local_buffer_);
172+
return ok;
173+
} else {
174+
return true;
175+
}
176+
}
112177

113178
} // namespace reader
114179
} // namespace operators

paddle/fluid/operators/reader/create_random_data_generator_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ namespace operators {
1919
namespace reader {
2020

2121
template <typename T>
22-
class RandomDataGenerator : public framework::FileReader {
22+
class RandomDataGenerator : public framework::ReaderBase {
2323
public:
2424
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
2525
float max)
26-
: FileReader(shapes), min_(min), max_(max) {
26+
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
2727
PADDLE_ENFORCE_LE(
2828
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
2929
unsigned int seed = std::random_device()();
@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
5959
float max_;
6060
std::minstd_rand engine_;
6161
std::uniform_real_distribution<float> dist_;
62+
std::vector<framework::DDim> shapes_;
6263
};
6364

6465
template <typename T>

paddle/fluid/operators/reader/create_recordio_file_reader_op.cc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,22 @@ namespace operators {
2020
namespace reader {
2121
class RecordIOFileReader : public framework::FileReader {
2222
public:
23-
RecordIOFileReader(const std::string& filename,
24-
const std::vector<framework::DDim>& shapes)
25-
: FileReader(shapes),
23+
explicit RecordIOFileReader(const std::string& filename,
24+
const std::vector<framework::DDim>& dims)
25+
: FileReader(dims),
2626
scanner_(filename),
2727
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
2828
platform::CPUPlace())) {}
2929

30-
void ReadNext(std::vector<framework::LoDTensor>* out) override {
31-
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
32-
}
33-
3430
bool HasNext() const override { return scanner_.HasNext(); }
3531

3632
void ReInit() override { scanner_.Reset(); }
3733

34+
protected:
35+
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
36+
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
37+
}
38+
3839
private:
3940
recordio::Scanner scanner_;
4041
const platform::DeviceContext& dev_ctx_;
@@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
5455
int(shape_concat.size()),
5556
"The accumulate of all ranks should be equal to the "
5657
"shape concat's length.");
57-
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
5858
std::string filename = Attr<std::string>("filename");
5959

6060
auto* out = scope.FindVar(Output("Out"))
6161
->template GetMutable<framework::ReaderHolder>();
62-
out->Reset(new RecordIOFileReader(filename, shapes));
62+
out->Reset(
63+
new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks)));
6364
}
6465
};
6566

@@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader;
8586
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
8687
reader::CreateRecordIOReaderOp,
8788
reader::CreateRecordIOReaderOpMaker);
89+
90+
REGISTER_FILE_READER(recordio, reader::RecordIOFileReader);

0 commit comments

Comments (0)