Skip to content

Commit 6d6f49c

Browse files
committed
Merge remote-tracking branch 'yuyang/feature/decorated_reader_chain' into dev_reader_ResetAll
2 parents 611716e + 62c1133 commit 6d6f49c

13 files changed

+138
-19
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
2727
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
2828

2929
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
30+
cc_test(reader_test SRCS reader_test.cc DEPS reader)
3031

3132
cc_test(variable_test SRCS variable_test.cc)
3233

paddle/fluid/framework/reader.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/reader.h"
16+
#include <deque>
1617

1718
namespace paddle {
1819
namespace framework {
@@ -23,6 +24,33 @@ void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
2324
ReadNextImpl(out);
2425
}
2526

27+
void ReaderBase::InsertDecoratedReader(
28+
const std::shared_ptr<ReaderBase> &decorated_reader) {
29+
std::lock_guard<std::mutex> guard(mu_));
30+
decorated_readers_.emplace_back(decorated_reader);
31+
}
32+
33+
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
34+
std::unordered_set<ReaderBase *> result;
35+
std::deque<ReaderBase *> queue;
36+
queue.emplace_back(this);
37+
while (!queue.empty()) { // BFS search
38+
auto *front = queue.front();
39+
queue.pop_front();
40+
if (front->decorated_readers_.empty()) {
41+
result.emplace(front);
42+
} else {
43+
for (auto &reader : front->decorated_readers_) {
44+
if (auto *reader_ptr = reader.lock().get()) {
45+
queue.emplace_back(reader_ptr);
46+
}
47+
}
48+
}
49+
}
50+
51+
return result;
52+
}
53+
2654
void ReaderBase::Shutdown() {
2755
std::lock_guard<std::mutex> lock(mu_);
2856
if (status_ != ReaderStatus::kStopped) {

paddle/fluid/framework/reader.h

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include <memory>
18+
#include <unordered_set>
1819
#include <vector>
1920

2021
#include "paddle/fluid/framework/ddim.h"
@@ -34,6 +35,10 @@ class ReaderBase {
3435

3536
void Start();
3637

38+
// Return the readers which are the end of decorating chain. Basically
39+
// they are readers just before read op.
40+
std::unordered_set<ReaderBase*> GetEndPoints();
41+
3742
virtual ~ReaderBase();
3843

3944
protected:
@@ -46,15 +51,29 @@ class ReaderBase {
4651
ReaderStatus status_{kRunning};
4752

4853
mutable std::mutex mu_;
54+
55+
private:
56+
friend class DecoratedReader;
57+
// These methods can be only invoked inside DecoratedReader to record the
58+
// decorating chain.
59+
void InsertDecoratedReader(
60+
const std::shared_ptr<ReaderBase>& decorated_reader);
61+
// A set of which readers that decorated this reader.
62+
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
4963
};
5064

51-
class DecoratedReader : public ReaderBase {
65+
class DecoratedReader : public ReaderBase,
66+
public std::enable_shared_from_this<DecoratedReader> {
5267
public:
5368
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
5469
: ReaderBase(), reader_(reader) {
5570
PADDLE_ENFORCE_NOT_NULL(reader_);
5671
}
5772

73+
void RegisterDecorateChain() {
74+
reader_->InsertDecoratedReader(shared_from_this());
75+
}
76+
5877
protected:
5978
void ShutdownImpl() override { reader_->Shutdown(); }
6079

@@ -70,9 +89,14 @@ class FileReader : public ReaderBase {};
7089
// making it easier to access different type reader in Variables.
7190
class ReaderHolder {
7291
public:
73-
void Reset(ReaderBase* reader) { reader_.reset(reader); }
92+
template <typename T>
93+
void Reset(const std::shared_ptr<T>& reader) {
94+
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
95+
PADDLE_ENFORCE_NOT_NULL(reader_base);
96+
reader_ = reader_base;
97+
}
7498

75-
std::shared_ptr<ReaderBase> Get() const { return reader_; }
99+
const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
76100

77101
void ReadNext(std::vector<LoDTensor>* out) {
78102
PADDLE_ENFORCE_NOT_NULL(reader_);
@@ -93,9 +117,18 @@ class ReaderHolder {
93117
reader_->Start();
94118
}
95119

120+
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
121+
96122
private:
97123
std::shared_ptr<ReaderBase> reader_;
98124
};
99125

126+
template <typename T, typename... ARGS>
127+
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
128+
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
129+
reader->RegisterDecorateChain();
130+
return reader;
131+
}
132+
100133
} // namespace framework
101134
} // namespace paddle

paddle/fluid/framework/reader_test.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/reader.h"
16+
#include <memory>
17+
#include "gtest/gtest.h"
18+
19+
class StubDecoratedReader : public paddle::framework::DecoratedReader {
20+
public:
21+
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
22+
: DecoratedReader(reader) {}
23+
24+
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
25+
};
26+
27+
class StubRootReader : public paddle::framework::ReaderBase {
28+
public:
29+
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
30+
void ReInit() override {}
31+
};
32+
33+
TEST(READER, decorate_chain) {
34+
auto root = std::make_shared<StubRootReader>();
35+
auto end_point1 =
36+
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
37+
auto end_point2 =
38+
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
39+
40+
{
41+
auto endpoints = root->GetEndPoints();
42+
ASSERT_EQ(endpoints.size(), 2U);
43+
ASSERT_NE(endpoints.count(end_point1.get()), 0);
44+
ASSERT_NE(endpoints.count(end_point2.get()), 0);
45+
}
46+
47+
{
48+
auto end_point3 =
49+
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
50+
ASSERT_EQ(root->GetEndPoints().size(), 3U);
51+
}
52+
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
53+
}

paddle/fluid/operators/reader/create_batch_reader_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
5050
}
5151
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
5252
->Get<framework::ReaderHolder>();
53-
out->Reset(new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"),
54-
Attr<bool>("discard_leftover")));
53+
out->Reset(framework::MakeDecoratedReader<BatchReader>(
54+
underlying_reader, Attr<int>("batch_size"),
55+
Attr<bool>("discard_leftover")));
5556
}
5657
};
5758

paddle/fluid/operators/reader/create_custom_reader_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
6060
}
6161
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
6262
->Get<framework::ReaderHolder>();
63-
out->Reset(
64-
new CustomReader(underlying_reader.Get(), *sub_block,
65-
Attr<std::vector<std::string>>("source_var_names"),
66-
Attr<std::vector<std::string>>("sink_var_names")));
63+
out->Reset(framework::MakeDecoratedReader<CustomReader>(
64+
underlying_reader, *sub_block,
65+
Attr<std::vector<std::string>>("source_var_names"),
66+
Attr<std::vector<std::string>>("sink_var_names")));
6767
}
6868
};
6969

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
118118
place = platform::CUDAPlace(static_cast<int>(num));
119119
}
120120

121-
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
121+
out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
122+
underlying_reader, place));
122123
}
123124
};
124125

paddle/fluid/operators/reader/create_multi_pass_reader_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
5959
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
6060
->Get<framework::ReaderHolder>();
6161
int pass_num = Attr<int>("pass_num");
62-
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
62+
out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
63+
underlying_reader, pass_num));
6364
}
6465
};
6566

paddle/fluid/operators/reader/create_py_reader_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
6363
auto* queue_holder =
6464
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
6565

66-
out->Reset(new PyReader(queue_holder->GetQueue()));
66+
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
6767
}
6868
};
6969

paddle/fluid/operators/reader/create_random_data_generator_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
7777
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
7878
auto* out = scope.FindVar(Output("Out"))
7979
->template GetMutable<framework::ReaderHolder>();
80-
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"),
81-
Attr<float>("high")));
80+
out->Reset(std::make_shared<RandomDataGenerator<T>>(
81+
shapes, Attr<float>("low"), Attr<float>("high")));
8282
}
8383
};
8484

0 commit comments

Comments
 (0)