Skip to content

Commit 3874c38

Browse files
authored
Merge pull request #9596 from JiayiFeng/update_reader
Update reader
2 parents 2eddafe + 38ba7e5 commit 3874c38

File tree

7 files changed

+75
-19
lines changed

7 files changed

+75
-19
lines changed

paddle/fluid/operators/reader/create_batch_reader_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase {
3939
private:
4040
void RunImpl(const framework::Scope& scope,
4141
const platform::Place& dev_place) const override {
42-
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
43-
->Get<framework::ReaderHolder>();
4442
auto* out = scope.FindVar(Output("Out"))
4543
->template GetMutable<framework::ReaderHolder>();
44+
if (out->Get() != nullptr) {
45+
return;
46+
}
47+
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
48+
->Get<framework::ReaderHolder>();
4649
out->Reset(
4750
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
4851
}

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
9999
private:
100100
void RunImpl(const framework::Scope& scope,
101101
const platform::Place& dev_place) const override {
102-
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
103-
->Get<framework::ReaderHolder>();
104102
auto* out = scope.FindVar(Output("Out"))
105103
->template GetMutable<framework::ReaderHolder>();
104+
if (out->Get() != nullptr) {
105+
return;
106+
}
107+
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
108+
->Get<framework::ReaderHolder>();
106109

107110
auto place_str = Attr<std::string>("place");
108111
platform::Place place;

paddle/fluid/operators/reader/create_multi_pass_reader_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
6262
private:
6363
void RunImpl(const framework::Scope& scope,
6464
const platform::Place& dev_place) const override {
65+
auto* out = detail::Ref(scope.FindVar(Output("Out")))
66+
.GetMutable<framework::ReaderHolder>();
67+
if (out->Get() != nullptr) {
68+
return;
69+
}
6570
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
6671
->Get<framework::ReaderHolder>();
67-
auto& out = detail::Ref(scope.FindVar(Output("Out")));
6872
int pass_num = Attr<int>("pass_num");
69-
out.GetMutable<framework::ReaderHolder>()->Reset(
70-
new MultiPassReader(underlying_reader.Get(), pass_num));
73+
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
7174
}
7275
};
7376

paddle/fluid/operators/reader/create_shuffle_reader_op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
8080
private:
8181
void RunImpl(const framework::Scope& scope,
8282
const platform::Place& dev_place) const override {
83+
auto* out = detail::Ref(scope.FindVar(Output("Out")))
84+
.GetMutable<framework::ReaderHolder>();
85+
if (out->Get() != nullptr) {
86+
return;
87+
}
8388
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
8489
->Get<framework::ReaderHolder>();
85-
auto& var = detail::Ref(scope.FindVar(Output("Out")));
86-
var.GetMutable<framework::ReaderHolder>()->Reset(
90+
out->Reset(
8791
new ShuffleReader(underlying_reader.Get(),
8892
static_cast<size_t>(Attr<int>("buffer_size"))));
8993
}

python/paddle/fluid/framework.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,20 @@ def block_attr(self, name):
640640
"""
641641
return self.desc.block_attr(name)
642642

643+
def all_attrs(self):
644+
"""
645+
Get the attribute dict
646+
Returns(dict): The Operator's attribute dict
647+
"""
648+
attr_names = self.attr_names
649+
attr_map = {}
650+
for n in attr_names:
651+
if n == 'sub_block':
652+
attr_map[n] = self.block_attr(n)
653+
else:
654+
attr_map[n] = self.attr(n)
655+
return attr_map
656+
643657

644658
class Block(object):
645659
def __init__(self, program, idx):

python/paddle/fluid/layers/io.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,32 @@ def _copy_reader_var_(block, var):
255255
new_var.desc.set_shapes(var.desc.shapes())
256256
new_var.desc.set_dtypes(var.desc.dtypes())
257257
new_var.persistable = True
258-
return monkey_patch_reader_methods(new_var)
258+
return new_var
259+
260+
261+
def _copy_reader_create_op_(block, op):
262+
input_param_names = op.input_names
263+
new_input_map = {}
264+
for param_name in input_param_names:
265+
new_input_map[param_name] = []
266+
arg_names = op.input(param_name)
267+
for arg_name in arg_names:
268+
new_input_map[param_name].append(block.var(arg_name))
269+
270+
output_param_names = op.output_names
271+
new_output_map = {}
272+
for param_name in output_param_names:
273+
new_output_map[param_name] = []
274+
arg_names = op.output(param_name)
275+
for arg_name in arg_names:
276+
new_output_map[param_name].append(block.var(arg_name))
277+
278+
new_op = block.append_op(
279+
type=op.type,
280+
inputs=new_input_map,
281+
outputs=new_output_map,
282+
attrs=op.all_attrs())
283+
return new_op
259284

260285

261286
def open_recordio_file(filename, shapes, lod_levels, dtypes):
@@ -283,8 +308,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
283308

284309
startup_var.desc.set_dtypes(dtypes)
285310
startup_var.persistable = True
286-
return _copy_reader_var_(default_main_program().current_block(),
287-
startup_var)
311+
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
312+
startup_var)
313+
return monkey_patch_reader_methods(main_prog_var)
288314

289315

290316
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
@@ -313,22 +339,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
313339

314340
startup_var.desc.set_dtypes(dtypes)
315341
startup_var.persistable = True
316-
return _copy_reader_var_(default_main_program().current_block(),
317-
startup_var)
342+
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
343+
startup_var)
344+
return monkey_patch_reader_methods(main_prog_var)
318345

319346

320347
def __create_decorated_reader__(op_type, reader, attrs):
321348
var_name = unique_name(op_type)
322349
startup_blk = default_startup_program().current_block()
323350
startup_var = startup_blk.create_var(name=var_name)
324-
startup_blk.append_op(
351+
startop_op = startup_blk.append_op(
325352
type=op_type,
326353
inputs={'UnderlyingReader': reader},
327354
outputs={'Out': [startup_var]},
328355
attrs=attrs)
329356
startup_var.persistable = True
330-
return _copy_reader_var_(default_main_program().current_block(),
331-
startup_var)
357+
main_prog_block = default_main_program().current_block()
358+
main_prog_var = _copy_reader_var_(main_prog_block, startup_var)
359+
_copy_reader_create_op_(main_prog_block, startop_op)
360+
return monkey_patch_reader_methods(main_prog_var)
332361

333362

334363
def create_shuffle_reader(reader, buffer_size):

python/paddle/fluid/tests/unittests/test_recordio_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import unittest
1616

1717
import paddle.fluid as fluid
18-
import paddle
19-
import paddle.dataset.mnist as mnist
18+
import paddle.v2 as paddle
19+
import paddle.v2.dataset.mnist as mnist
2020

2121

2222
class TestRecordIO(unittest.TestCase):

0 commit comments

Comments
 (0)