Skip to content

Commit cd8700f

Browse files
authored
Merge pull request #10872 from JiayiFeng/dev_CustomReader
CustomReader
2 parents 7530366 + 8147063 commit cd8700f

File tree

6 files changed

+365
-5
lines changed

6 files changed

+365
-5
lines changed

paddle/fluid/framework/shape_inference.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class InferShapeContext {
6363

6464
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
6565
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
66+
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
6667

6768
// Note: In while op, we need this to be public
6869
void SetDims(const std::vector<std::string> &names,
@@ -81,8 +82,6 @@ class InferShapeContext {
8182
const std::vector<std::string> &names) const;
8283

8384
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
84-
85-
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
8685
};
8786

8887
} // namespace framework

paddle/fluid/operators/reader/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
2323
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
2424
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
2525
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
26+
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
2627

2728
cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
2829
# Export local libraries to parent
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/executor.h"
16+
#include "paddle/fluid/operators/detail/safe_ref.h"
17+
#include "paddle/fluid/operators/reader/reader_op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
namespace reader {
22+
23+
class CustomReader : public framework::DecoratedReader {
24+
public:
25+
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
26+
const platform::Place& dev_place,
27+
const std::vector<std::string>& source_var_names,
28+
const std::vector<std::string>& sink_var_names)
29+
: DecoratedReader(reader),
30+
program_(*sub_block.Program()),
31+
sub_block_id_(sub_block.ID()),
32+
exe_(framework::Executor(dev_place)),
33+
source_var_names_(source_var_names),
34+
sink_var_names_(sink_var_names) {}
35+
36+
void ReadNext(std::vector<framework::LoDTensor>* out) override;
37+
38+
private:
39+
const framework::ProgramDesc program_;
40+
int sub_block_id_;
41+
framework::Executor exe_;
42+
43+
std::vector<std::string> source_var_names_;
44+
std::vector<std::string> sink_var_names_;
45+
};
46+
47+
class CreateCustomReaderOp : public framework::OperatorBase {
48+
public:
49+
using framework::OperatorBase::OperatorBase;
50+
51+
private:
52+
void RunImpl(const framework::Scope& scope,
53+
const platform::Place& dev_place) const override {
54+
auto* out = scope.FindVar(Output("Out"))
55+
->template GetMutable<framework::ReaderHolder>();
56+
auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
57+
if (out->Get() != nullptr) {
58+
return;
59+
}
60+
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
61+
->Get<framework::ReaderHolder>();
62+
out->Reset(
63+
new CustomReader(underlying_reader.Get(), *sub_block, dev_place,
64+
Attr<std::vector<std::string>>("source_var_names"),
65+
Attr<std::vector<std::string>>("sink_var_names")));
66+
}
67+
};
68+
69+
class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
70+
protected:
71+
void Apply() override {
72+
AddAttr<framework::BlockDesc*>(
73+
"sub_block", "The block to hold all preprocessing operators.");
74+
AddAttr<std::vector<std::string>>(
75+
"source_var_names",
76+
"Source variables are starting points of data preprocessing. They hold "
77+
"preprocessing's input tensors. Each source variable corresponds to "
78+
"one of underlying reader's output datas.");
79+
AddAttr<std::vector<std::string>>(
80+
"sink_var_names",
81+
"Sink variables are ending points of data preprocessing. They hold "
82+
"preprocessing's output tensors. Each sink variable corresponds to "
83+
"one of custom reader's output datas.");
84+
AddComment(R"DOC(
85+
CreateCustomReader Operator
86+
87+
A custom reader can be used for input data preprocessing.
88+
A custom reader holds its own sub-block, which will be executed in its
89+
'ReadNext()' function. Users can configurate their own preprocessing
90+
pipelines by inserting operators into custom reader's sub-block.
91+
)DOC");
92+
}
93+
};
94+
95+
class CustomReaderInferShape : public framework::InferShapeBase {
96+
public:
97+
void operator()(framework::InferShapeContext* ctx) const override {
98+
PADDLE_ENFORCE(!ctx->IsRuntime(),
99+
"'CustomReaderInferShape' should only be invoked during "
100+
"compile time.");
101+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
102+
"The output decorated reader should not be null.");
103+
const auto* sub_block =
104+
ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
105+
const auto sink_var_names =
106+
ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");
107+
std::vector<std::vector<int64_t>> res_dims;
108+
std::vector<int32_t> res_lod_levels;
109+
for (const std::string& var_name : sink_var_names) {
110+
auto* sink_var = sub_block->FindVar(var_name);
111+
PADDLE_ENFORCE_NOT_NULL(sink_var);
112+
res_dims.emplace_back(sink_var->GetShape());
113+
res_lod_levels.push_back(sink_var->GetLoDLevel());
114+
}
115+
auto* out_reader =
116+
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
117+
out_reader->SetShapes(res_dims);
118+
out_reader->SetLoDLevels(res_lod_levels);
119+
}
120+
};
121+
122+
class CustomReaderInferVarType : public framework::VarTypeInference {
123+
public:
124+
void operator()(const framework::OpDesc& op_desc,
125+
framework::BlockDesc* block) const override {
126+
framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
127+
PADDLE_ENFORCE_NOT_NULL(out_reader);
128+
out_reader->SetType(framework::proto::VarType::READER);
129+
130+
auto sink_var_names =
131+
boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
132+
const auto* sub_block =
133+
boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
134+
std::vector<framework::proto::VarType::Type> res_data_types;
135+
for (const std::string& var_name : sink_var_names) {
136+
framework::VarDesc* var = sub_block->FindVar(var_name);
137+
PADDLE_ENFORCE_NOT_NULL(var);
138+
res_data_types.emplace_back(var->GetDataType());
139+
}
140+
out_reader->SetDataTypes(res_data_types);
141+
}
142+
};
143+
144+
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
145+
out->clear();
146+
std::vector<framework::LoDTensor> underlying_outs;
147+
reader_->ReadNext(&underlying_outs);
148+
if (underlying_outs.empty()) {
149+
// There is not next data.
150+
return;
151+
}
152+
PADDLE_ENFORCE(source_var_names_.size() == underlying_outs.size(),
153+
"The size of source_var_names(%d) and the size of "
154+
"underlying_outs(%d) are not consistent. Each feeding element "
155+
"must have its own source variable.",
156+
source_var_names_.size(), underlying_outs.size());
157+
// The scope for CustomReader's sub-block should be independent and shouldn't
158+
// be any other computation scope's child. Otherwise, data preprocessing and
159+
// compution cannot be concurrent.
160+
framework::Scope scope;
161+
// 1. Copy LoDTensors from underlying reader's output to source variables.
162+
for (size_t i = 0; i < source_var_names_.size(); ++i) {
163+
framework::Variable* var = scope.Var(source_var_names_[i]);
164+
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
165+
tensor->ShareDataWith(underlying_outs[i]);
166+
tensor->set_lod(underlying_outs[i].lod());
167+
}
168+
// 2. Run the sub-block.
169+
exe_.Run(program_, &scope, sub_block_id_, false, true);
170+
// 3. Copy LoDTensors from sink variables to out.
171+
out->resize(sink_var_names_.size());
172+
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
173+
const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
174+
.Get<framework::LoDTensor>();
175+
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
176+
}
177+
}
178+
179+
} // namespace reader
180+
} // namespace operators
181+
} // namespace paddle
182+
183+
namespace ops = paddle::operators::reader;
184+
REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp,
185+
ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape,
186+
ops::CustomReaderInferVarType,
187+
paddle::framework::EmptyGradOpMaker)

paddle/fluid/operators/reader/reader_op_registry.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ void DecoratedReaderInferShape::operator()(
115115
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
116116
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
117117
}
118+
118119
void DecoratedReaderInferVarType::operator()(
119120
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
120121
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];

python/paddle/fluid/layers/io.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import contextlib
1415

1516
from .. import core
1617
from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program
@@ -21,7 +22,8 @@
2122

2223
__all__ = [
2324
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
24-
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer'
25+
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
26+
'Preprocessor'
2527
]
2628

2729

@@ -535,8 +537,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
535537
inputs={'UnderlyingReader': reader},
536538
outputs={'Out': [new_reader]},
537539
attrs=attrs)
538-
new_reader.persistable = True
539-
new_reader.stop_gradient = True
540540
return monkey_patch_reader_methods(new_reader)
541541

542542

@@ -581,3 +581,82 @@ def read_file(file_obj):
581581
return out[0]
582582
else:
583583
return out
584+
585+
586+
class Preprocessor(object):
587+
BEFORE_SUB_BLOCK = 0
588+
IN_SUB_BLOCK = 1
589+
AFTER_SUB_BLOCK = 2
590+
591+
def __init__(self, reader, name=None):
592+
self.underlying_reader = reader
593+
new_reader_name = name if name is not None else unique_name(
594+
"create_custom_reader")
595+
self.main_prog = default_main_program()
596+
self.reader = self.main_prog.current_block().create_var(
597+
name=new_reader_name)
598+
self.sub_block = None
599+
self.source_var_names = None
600+
self.sink_var_names = None
601+
self.status = Preprocessor.BEFORE_SUB_BLOCK
602+
603+
def is_completed(self):
604+
return self.sub_block and self.source_var_names and self.sink_var_names
605+
606+
@contextlib.contextmanager
607+
def block(self):
608+
self.status = Preprocessor.IN_SUB_BLOCK
609+
self.sub_block = self.main_prog.create_block()
610+
yield
611+
self.main_prog.rollback()
612+
self.status = Preprocessor.AFTER_SUB_BLOCK
613+
if not self.is_completed():
614+
raise RuntimeError(
615+
"The definition of preprocessor is incompleted! "
616+
"Please make sure that you have set input and output "
617+
"variables by invoking 'inputs' and 'outputs' in "
618+
"Preprocessor's sub-block.")
619+
620+
def inputs(self):
621+
if self.status != Preprocessor.IN_SUB_BLOCK:
622+
raise RuntimeError(
623+
"Preprocessor.inputs() can only be invoked inside the sub-block."
624+
)
625+
626+
source_shapes = self.underlying_reader.desc.shapes()
627+
source_dtypes = self.underlying_reader.desc.dtypes()
628+
source_lod_levels = self.underlying_reader.desc.lod_levels()
629+
self.source_var_names = [
630+
unique_name("preprocessor_source")
631+
for _ in xrange(len(source_shapes))
632+
]
633+
source_vars = []
634+
for var_name, shape, dtype, lod_level in zip(
635+
self.source_var_names, source_shapes, source_dtypes,
636+
source_lod_levels):
637+
source_vars.append(self.main_prog.current_block().create_var(
638+
name=var_name, shape=shape, dtype=dtype, lod_level=lod_level))
639+
return source_vars
640+
641+
def outputs(self, *outs):
642+
if self.status != Preprocessor.IN_SUB_BLOCK:
643+
raise RuntimeError(
644+
"Preprocessor.outputs() can only be invoked inside the sub-block."
645+
)
646+
self.sink_var_names = [var.name for var in outs]
647+
648+
def __call__(self, *args, **kwargs):
649+
if self.status != Preprocessor.AFTER_SUB_BLOCK:
650+
raise RuntimeError(
651+
"Preprocessor output can only be retrieved after rnn block.")
652+
653+
self.main_prog.current_block().append_op(
654+
type="create_custom_reader",
655+
inputs={'UnderlyingReader': self.underlying_reader},
656+
outputs={'Out': [self.reader]},
657+
attrs={
658+
"sub_block": self.sub_block,
659+
"source_var_names": self.source_var_names,
660+
"sink_var_names": self.sink_var_names
661+
})
662+
return monkey_patch_reader_methods(self.reader)

0 commit comments

Comments
 (0)