Skip to content

Commit 26ae611

Browse files
authored
Merge pull request #12051 from JiayiFeng/dev_reader_ResetAll
[WIP] Dev reader reset all
2 parents 10fbb83 + d55919c commit 26ae611

19 files changed

+283
-219
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
2727
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
2828

2929
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
30+
cc_test(reader_test SRCS reader_test.cc DEPS reader)
3031

3132
cc_test(variable_test SRCS variable_test.cc)
3233

paddle/fluid/framework/reader.cc

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,61 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/reader.h"
16+
#include <deque>
1617

1718
namespace paddle {
1819
namespace framework {
19-
ReaderBase::~ReaderBase() {}
2020

21-
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
22-
23-
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
21+
void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
22+
std::lock_guard<std::mutex> lock(mu_);
23+
PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning);
2424
ReadNextImpl(out);
25-
if (out->empty()) {
26-
return;
27-
}
25+
}
2826

29-
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
30-
for (size_t i = 0; i < dims_.size(); ++i) {
31-
auto &actual = (*out)[i].dims();
32-
auto &expect = dims_[i];
27+
void ReaderBase::InsertDecoratedReader(
28+
const std::shared_ptr<ReaderBase> &decorated_reader) {
29+
std::lock_guard<std::mutex> guard(mu_);
30+
decorated_readers_.emplace_back(decorated_reader);
31+
}
3332

34-
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
35-
for (int j = 0; j < actual.size(); ++j) {
36-
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
33+
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
34+
std::unordered_set<ReaderBase *> result;
35+
std::deque<ReaderBase *> queue;
36+
queue.emplace_back(this);
37+
while (!queue.empty()) { // BFS search
38+
auto *front = queue.front();
39+
queue.pop_front();
40+
if (front->decorated_readers_.empty()) {
41+
result.emplace(front);
42+
} else {
43+
for (auto &reader : front->decorated_readers_) {
44+
if (auto *reader_ptr = reader.lock().get()) {
45+
queue.emplace_back(reader_ptr);
46+
}
47+
}
3748
}
3849
}
50+
51+
return result;
3952
}
53+
54+
void ReaderBase::Shutdown() {
55+
std::lock_guard<std::mutex> lock(mu_);
56+
if (status_ != ReaderStatus::kStopped) {
57+
ShutdownImpl();
58+
status_ = ReaderStatus::kStopped;
59+
}
60+
}
61+
62+
void ReaderBase::Start() {
63+
std::lock_guard<std::mutex> lock(mu_);
64+
if (status_ != ReaderStatus::kRunning) {
65+
StartImpl();
66+
status_ = ReaderStatus::kRunning;
67+
}
68+
}
69+
70+
// Shuts the reader down on destruction. By the time this base destructor
// runs, virtual calls no longer reach derived overrides, so the ShutdownImpl()
// invoked through Shutdown() here is ReaderBase's own empty default. A derived
// reader whose ShutdownImpl() must run would need to call Shutdown() from its
// own destructor.
ReaderBase::~ReaderBase() { Shutdown(); }
71+
4072
} // namespace framework
4173
} // namespace paddle

paddle/fluid/framework/reader.h

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <memory>
18+
#include <unordered_set>
1819
#include <vector>
1920

2021
#include "paddle/fluid/framework/ddim.h"
@@ -24,61 +25,116 @@
2425
namespace paddle {
2526
namespace framework {
2627

28+
enum ReaderStatus { kRunning, kStopped };
29+
2730
class ReaderBase {
2831
public:
29-
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
32+
void ReadNext(std::vector<LoDTensor>* out);
33+
34+
void Shutdown();
3035

31-
virtual void ReInit() = 0;
36+
void Start();
37+
38+
// Return the readers at the end of the decorating chain. Basically,
39+
// they are the readers just before the read op.
40+
std::unordered_set<ReaderBase*> GetEndPoints();
3241

3342
virtual ~ReaderBase();
43+
44+
protected:
45+
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
46+
47+
virtual void ShutdownImpl() {}
48+
49+
virtual void StartImpl() {}
50+
51+
ReaderStatus status_{kRunning};
52+
53+
mutable std::mutex mu_;
54+
55+
private:
56+
friend class DecoratedReader;
57+
// This method can only be invoked inside DecoratedReader to record the
58+
// decorating chain.
59+
void InsertDecoratedReader(
60+
const std::shared_ptr<ReaderBase>& decorated_reader);
61+
// The readers that decorate this reader.
62+
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
3463
};
3564

36-
class DecoratedReader : public ReaderBase {
65+
class DecoratedReader : public ReaderBase,
66+
public std::enable_shared_from_this<DecoratedReader> {
3767
public:
3868
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
3969
: ReaderBase(), reader_(reader) {
4070
PADDLE_ENFORCE_NOT_NULL(reader_);
4171
}
4272

43-
void ReInit() override { reader_->ReInit(); }
73+
void RegisterDecorateChain() {
74+
reader_->InsertDecoratedReader(shared_from_this());
75+
}
4476

4577
protected:
46-
std::shared_ptr<ReaderBase> reader_;
47-
};
48-
49-
class FileReader : public ReaderBase {
50-
public:
51-
explicit FileReader(const std::vector<DDim>& dims);
52-
53-
void ReadNext(std::vector<LoDTensor>* out) override;
78+
void ShutdownImpl() override { reader_->Shutdown(); }
5479

55-
protected:
56-
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
80+
void StartImpl() override { reader_->Start(); }
5781

58-
private:
59-
std::vector<DDim> dims_;
82+
std::shared_ptr<ReaderBase> reader_;
6083
};
6184

85+
// FileReader is just a conceptual class.
86+
class FileReader : public ReaderBase {};
87+
6288
// The ReaderHolder is used as the readers' unified wrapper,
6389
// making it easier to access different types of readers in Variables.
6490
class ReaderHolder {
6591
public:
66-
void Reset(ReaderBase* reader) { reader_.reset(reader); }
92+
template <typename T>
93+
void Reset(const std::shared_ptr<T>& reader) {
94+
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
95+
PADDLE_ENFORCE_NOT_NULL(reader_base);
96+
reader_ = reader_base;
97+
}
6798

68-
std::shared_ptr<ReaderBase> Get() const { return reader_; }
99+
const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
69100

70101
void ReadNext(std::vector<LoDTensor>* out) {
71102
PADDLE_ENFORCE_NOT_NULL(reader_);
72103
reader_->ReadNext(out);
73104
}
74-
void ReInit() {
105+
106+
void ResetAll() {
107+
auto end_readers = reader_->GetEndPoints();
108+
for (auto* reader : end_readers) {
109+
reader->Shutdown();
110+
}
111+
for (auto* reader : end_readers) {
112+
reader->Start();
113+
}
114+
}
115+
116+
void Shutdown() {
117+
PADDLE_ENFORCE_NOT_NULL(reader_);
118+
reader_->Shutdown();
119+
}
120+
121+
void Start() {
75122
PADDLE_ENFORCE_NOT_NULL(reader_);
76-
reader_->ReInit();
123+
reader_->Start();
77124
}
78125

126+
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
127+
79128
private:
80129
std::shared_ptr<ReaderBase> reader_;
81130
};
82131

132+
template <typename T, typename... ARGS>
133+
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
134+
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
135+
reader->RegisterDecorateChain();
136+
return reader;
137+
}
138+
83139
} // namespace framework
84140
} // namespace paddle

paddle/fluid/framework/reader_test.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/reader.h"
16+
#include <memory>
17+
#include "gtest/gtest.h"
18+
19+
class StubDecoratedReader : public paddle::framework::DecoratedReader {
20+
public:
21+
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
22+
: DecoratedReader(reader) {}
23+
24+
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
25+
};
26+
27+
class StubRootReader : public paddle::framework::ReaderBase {
28+
public:
29+
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
30+
};
31+
32+
// Checks that GetEndPoints() reports the readers decorating `root` and stops
// reporting a decorated reader once it has been destroyed.
TEST(READER, decorate_chain) {
  auto root = std::make_shared<StubRootReader>();
  auto end_point1 =
      paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
  auto end_point2 =
      paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);

  {
    // Both decorators of root are alive, so both must be end points.
    auto endpoints = root->GetEndPoints();
    ASSERT_EQ(endpoints.size(), 2U);
    // count() returns an unsigned size_type; use an unsigned literal so the
    // comparison inside the assertion does not mix signed and unsigned.
    ASSERT_NE(endpoints.count(end_point1.get()), 0U);
    ASSERT_NE(endpoints.count(end_point2.get()), 0U);
  }

  {
    auto end_point3 =
        paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
    ASSERT_EQ(root->GetEndPoints().size(), 3U);
  }
  // end_point3 went out of scope above, so only two end points remain.
  { ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
2222
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
2323
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
2424
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
25-
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
2625
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
2726
reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
2827

paddle/fluid/operators/reader/create_batch_reader_op.cc

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,19 @@ namespace reader {
2020

2121
class BatchReader : public framework::DecoratedReader {
2222
public:
23-
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
24-
: DecoratedReader(reader), batch_size_(batch_size) {
23+
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
24+
bool discard_leftover)
25+
: DecoratedReader(reader),
26+
batch_size_(batch_size),
27+
discard_leftover_(discard_leftover) {
2528
buffer_.reserve(batch_size_);
2629
}
2730

28-
void ReadNext(std::vector<framework::LoDTensor>* out) override;
31+
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
2932

3033
private:
3134
int batch_size_;
35+
bool discard_leftover_;
3236
std::vector<std::vector<framework::LoDTensor>> buffer_;
3337
};
3438

@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
4650
}
4751
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
4852
->Get<framework::ReaderHolder>();
49-
out->Reset(
50-
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
53+
out->Reset(framework::MakeDecoratedReader<BatchReader>(
54+
underlying_reader, Attr<int>("batch_size"),
55+
Attr<bool>("discard_leftover")));
5156
}
5257
};
5358

@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
5762
AddAttr<int>("batch_size",
5863
"How many instances the batch reader yields each time.")
5964
.GreaterThan(0);
65+
AddAttr<bool>("discard_leftover",
66+
"If true, the leftover instances that are not enough for a "
67+
"new batch will be discarded.")
68+
.SetDefault(true);
6069
AddComment(R"DOC(
6170
CreateBatchReader Operator
6271
@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
6675
}
6776
};
6877

69-
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
78+
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
7079
buffer_.clear();
7180
buffer_.reserve(batch_size_);
7281
for (int i = 0; i < batch_size_; ++i) {
@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
7786
break;
7887
}
7988
}
89+
if (discard_leftover_ && buffer_.size() < batch_size_) {
90+
buffer_.clear();
91+
}
8092
// Concat instances
8193
out->clear();
8294
if (buffer_.empty()) {

paddle/fluid/operators/reader/create_custom_reader_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
3333
source_var_names_(source_var_names),
3434
sink_var_names_(sink_var_names) {}
3535

36-
void ReadNext(std::vector<framework::LoDTensor>* out) override;
36+
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
3737

3838
private:
3939
const framework::ProgramDesc program_;
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
6060
}
6161
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
6262
->Get<framework::ReaderHolder>();
63-
out->Reset(
64-
new CustomReader(underlying_reader.Get(), *sub_block,
65-
Attr<std::vector<std::string>>("source_var_names"),
66-
Attr<std::vector<std::string>>("sink_var_names")));
63+
out->Reset(framework::MakeDecoratedReader<CustomReader>(
64+
underlying_reader, *sub_block,
65+
Attr<std::vector<std::string>>("source_var_names"),
66+
Attr<std::vector<std::string>>("sink_var_names")));
6767
}
6868
};
6969

@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
143143
}
144144
};
145145

146-
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
146+
void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
147147
out->clear();
148148
std::vector<framework::LoDTensor> underlying_outs;
149149
reader_->ReadNext(&underlying_outs);

0 commit comments

Comments
 (0)