Skip to content

Commit cd07c0f

Browse files
authored
Merge pull request #9259 from JiayiFeng/dev_MultiEpochReader
Multi-pass reader
2 parents c8e66e8 + 809530f commit cd07c0f

File tree

4 files changed

+173
-1
lines changed

4 files changed

+173
-1
lines changed

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
2121
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
2222
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
2323
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
24+
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
2425
# Export local libraries to parent
2526
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/detail/safe_ref.h"
16+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
namespace reader {
21+
22+
class MultiPassReader : public framework::DecoratedReader {
23+
public:
24+
MultiPassReader(ReaderBase* reader, int pass_num)
25+
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
26+
27+
void ReadNext(std::vector<framework::LoDTensor>* out) override {
28+
if (!HasNext()) {
29+
PADDLE_THROW("There is no next data!");
30+
}
31+
reader_->ReadNext(out);
32+
}
33+
34+
bool HasNext() const override {
35+
if (reader_->HasNext()) {
36+
return true;
37+
} else {
38+
++pass_count_;
39+
if (pass_count_ >= pass_num_) {
40+
return false;
41+
} else {
42+
reader_->ReInit();
43+
return true;
44+
}
45+
}
46+
}
47+
48+
void ReInit() override {
49+
pass_count_ = 0;
50+
reader_->ReInit();
51+
}
52+
53+
private:
54+
int pass_num_;
55+
mutable int pass_count_;
56+
};
57+
58+
class CreateMultiPassReaderOp : public framework::OperatorBase {
59+
public:
60+
using framework::OperatorBase::OperatorBase;
61+
62+
private:
63+
void RunImpl(const framework::Scope& scope,
64+
const platform::Place& dev_place) const override {
65+
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
66+
->Get<framework::ReaderHolder>();
67+
auto& out = detail::Ref(scope.FindVar(Output("Out")));
68+
int pass_num = Attr<int>("pass_num");
69+
out.GetMutable<framework::ReaderHolder>()->Reset(
70+
new MultiPassReader(underlying_reader.Get(), pass_num));
71+
}
72+
};
73+
74+
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
75+
public:
76+
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
77+
: DecoratedReaderMakerBase(op_proto, op_checker) {
78+
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
79+
AddComment(R"DOC(
80+
CreateMultiPassReader Operator
81+
82+
This operator creates a multi-pass reader. A multi-pass reader
83+
is used to yield data for several pass training continuously.
84+
It takes the the number of pass to run as one of its attributes
85+
('pass_num'), and maintains a pass counter to record how many
86+
passes it has completed. When the underlying reader reach the EOF,
87+
the multi-pass reader checks whether it has completed training
88+
of the given number of pass. If not, the underlying reader will
89+
be re-initialized and starts a new pass automatically.
90+
)DOC");
91+
}
92+
};
93+
94+
} // namespace reader
95+
} // namespace operators
96+
} // namespace paddle
97+
98+
namespace ops = paddle::operators::reader;
99+
REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
100+
ops::CreateMultiPassReaderOp,
101+
ops::CreateMultiPassReaderOpMaker);

python/paddle/fluid/layers/io.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
__all__ = [
2323
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
2424
'open_files', 'read_file', 'create_shuffle_reader',
25-
'create_double_buffer_reader'
25+
'create_double_buffer_reader', 'create_multi_pass_reader'
2626
]
2727

2828

@@ -345,6 +345,11 @@ def create_double_buffer_reader(reader, place=None):
345345
attrs)
346346

347347

348+
def create_multi_pass_reader(reader, pass_num):
349+
return __create_decorated_reader__('create_multi_pass_reader', reader,
350+
{'pass_num': int(pass_num)})
351+
352+
348353
def read_file(file_obj):
349354
helper = LayerHelper('read_file')
350355
out = [
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle.fluid as fluid
18+
import paddle.v2 as paddle
19+
import paddle.v2.dataset.mnist as mnist
20+
21+
22+
class TestMultipleReader(unittest.TestCase):
23+
def setUp(self):
24+
self.batch_size = 64
25+
self.pass_num = 3
26+
# Convert mnist to recordio file
27+
with fluid.program_guard(fluid.Program(), fluid.Program()):
28+
data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
29+
feeder = fluid.DataFeeder(
30+
feed_list=[
31+
fluid.layers.data(
32+
name='image', shape=[784]),
33+
fluid.layers.data(
34+
name='label', shape=[1], dtype='int64'),
35+
],
36+
place=fluid.CPUPlace())
37+
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
38+
'./mnist.recordio', data_file, feeder)
39+
40+
def test_main(self):
41+
with fluid.program_guard(fluid.Program(), fluid.Program()):
42+
data_file = fluid.layers.open_recordio_file(
43+
filename='./mnist.recordio',
44+
shapes=[(-1, 784), (-1, 1)],
45+
lod_levels=[0, 0],
46+
dtypes=['float32', 'int64'])
47+
data_file = fluid.layers.create_multi_pass_reader(
48+
reader=data_file, pass_num=self.pass_num)
49+
img, label = fluid.layers.read_file(data_file)
50+
51+
if fluid.core.is_compiled_with_cuda():
52+
place = fluid.CUDAPlace(0)
53+
else:
54+
place = fluid.CPUPlace()
55+
56+
exe = fluid.Executor(place)
57+
exe.run(fluid.default_startup_program())
58+
59+
batch_count = 0
60+
while not data_file.eof():
61+
img_val, = exe.run(fetch_list=[img])
62+
batch_count += 1
63+
self.assertLessEqual(img_val.shape[0], self.batch_size)
64+
data_file.reset()
65+
self.assertEqual(batch_count, self.num_batch * self.pass_num)

0 commit comments

Comments
 (0)