Skip to content

Commit a6c1bff

Browse files
committed
Merge with upstream
2 parents bc7be83 + 175aa7e commit a6c1bff

File tree

7 files changed

+89
-19
lines changed

7 files changed

+89
-19
lines changed

paddle/fluid/framework/op_desc.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
7777
void SetRepeatedDims(const std::string &name,
7878
const std::vector<DDim> &dims) override;
7979

80+
InferShapeVarPtr GetVarPtr(const std::string &name) override;
81+
8082
const OpDesc &op_;
8183
const BlockDesc &block_;
8284
};
@@ -510,5 +512,10 @@ proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
510512
return block_.FindVarRecursive(name)->GetType();
511513
}
512514

515+
InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
516+
const std::string &name) {
517+
return block_.FindVarRecursive(name);
518+
}
519+
513520
} // namespace framework
514521
} // namespace paddle

paddle/fluid/framework/operator.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
470470
return ToVarType(var->Type());
471471
}
472472

473+
InferShapeVarPtr GetVarPtr(const std::string& name) override {
474+
return scope_.FindVar(name);
475+
}
476+
473477
private:
474478
const OperatorBase& op_;
475479
const Scope& scope_;

paddle/fluid/framework/reader.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
9090

9191
// Merge lod and data
9292
LoD batch_lod;
93-
std::vector<size_t> top_level_lod({0});
9493
for (size_t i = 0; i < buffer_.size(); ++i) {
9594
DDim ins_shape = buffer_[i][j].dims();
9695
LoD ins_lod = buffer_[i][j].lod();
@@ -105,15 +104,10 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
105104
}
106105
}
107106
}
108-
top_level_lod.push_back(
109-
top_level_lod.back() +
110-
(ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
111-
112107
Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
113108
Copy(buffer_[i][j], platform::CPUPlace(), &dst);
114109
dst_offset += ins_shape[0];
115110
}
116-
batch_lod.insert(batch_lod.begin(), top_level_lod);
117111
out_tensor.set_lod(batch_lod);
118112
out->push_back(out_tensor);
119113
}

paddle/fluid/framework/shape_inference.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,28 @@ void InferShapeContext::SetReaderDims(const std::string &name,
7272
return this->SetRepeatedDims(arg_names[0], dims);
7373
}
7474

75+
std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
76+
const std::string &name) {
77+
const std::vector<std::string> arg_names = Inputs(name);
78+
std::vector<InferShapeVarPtr> res;
79+
res.reserve(arg_names.size());
80+
std::transform(
81+
arg_names.begin(), arg_names.end(), std::back_inserter(res),
82+
[this](const std::string &name) { return this->GetVarPtr(name); });
83+
return res;
84+
}
85+
86+
std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
87+
const std::string &name) {
88+
const std::vector<std::string> arg_names = Outputs(name);
89+
std::vector<InferShapeVarPtr> res;
90+
res.reserve(arg_names.size());
91+
std::transform(
92+
arg_names.begin(), arg_names.end(), std::back_inserter(res),
93+
[this](const std::string &name) { return this->GetVarPtr(name); });
94+
return res;
95+
}
96+
7597
std::vector<DDim> InferShapeContext::GetDims(
7698
const std::vector<std::string> &names) const {
7799
std::vector<DDim> ret;

paddle/fluid/framework/shape_inference.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/attribute.h"
1818
#include "paddle/fluid/framework/ddim.h"
1919
#include "paddle/fluid/framework/framework.pb.h"
20+
#include "paddle/fluid/framework/var_desc.h"
21+
#include "paddle/fluid/framework/variable.h"
2022

2123
namespace paddle {
2224
namespace framework {
2325

26+
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
27+
2428
class InferShapeContext {
2529
public:
2630
virtual ~InferShapeContext() = default;
@@ -55,6 +59,9 @@ class InferShapeContext {
5559

5660
virtual bool IsRuntime() const = 0;
5761

62+
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
63+
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
64+
5865
// Note: In while op, we need this to be public
5966
void SetDims(const std::vector<std::string> &names,
6067
const std::vector<DDim> &dims);
@@ -67,10 +74,13 @@ class InferShapeContext {
6774
const std::vector<DDim> &dims) = 0;
6875

6976
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
77+
7078
std::vector<proto::VarDesc::VarType> GetVarTypes(
7179
const std::vector<std::string> &names) const;
7280

7381
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
82+
83+
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
7484
};
7585

7686
} // namespace framework

paddle/fluid/operators/create_reader_op.cc

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
4242
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
4343
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
4444
ctx->SetReaderDims("Out", shapes);
45+
46+
if (ctx->IsRuntime()) {
47+
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
48+
PADDLE_ENFORCE_EQ(
49+
lod_levels.size(), shapes.size(),
50+
"The number of 'lod_levels'(%d) doesn't match the number "
51+
"of 'shapes'(%d).",
52+
lod_levels.size(), shapes.size());
53+
framework::VarDesc* reader =
54+
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
55+
reader->SetLoDLevels(lod_levels);
56+
}
4557
}
4658
};
4759

@@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
5466
PADDLE_ENFORCE(ctx->HasOutput("Out"),
5567
"The output decorated reader should not be null.");
5668
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
69+
70+
if (ctx->IsRuntime()) {
71+
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
72+
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
73+
framework::VarDesc* out_reader =
74+
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
75+
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
76+
}
5777
}
5878
};
5979

60-
// general var type inference for all readers
61-
class CreateReaderInferVarType : public framework::VarTypeInference {
80+
// general var type inference for file readers
81+
class CreateFileReaderInferVarType : public framework::VarTypeInference {
6282
public:
6383
void operator()(const framework::OpDesc& op_desc,
6484
framework::BlockDesc* block) const override {
@@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
6888
}
6989
};
7090

91+
// general var type inference for decorated readers
92+
class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
93+
public:
94+
void operator()(const framework::OpDesc& op_desc,
95+
framework::BlockDesc* block) const override {
96+
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
97+
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
98+
std::string out_reader_name = op_desc.Output("Out")[0];
99+
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
100+
out_reader->SetType(framework::proto::VarDesc::READER);
101+
out_reader->SetDataTypes(in_reader->GetDataTypes());
102+
}
103+
};
104+
71105
template <typename T>
72106
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
73107
public:
@@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker
105139
"ranks = [3,2]"
106140
"It means the reader will generate two data each time,"
107141
"whose shapes are [2,3,4] and [5,6] respectively.");
142+
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
108143
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
109144
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
110145
AddComment(R"DOC(
@@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator,
192227
ops::CreateFileReaderInferShape,
193228
ops::CreateRandomDataGeneratorOpMaker,
194229
paddle::framework::EmptyGradOpMaker,
195-
ops::CreateReaderInferVarType);
230+
ops::CreateFileReaderInferVarType);
196231
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
197232
ops::CreateDecoratedReaderInferShape,
198233
ops::CreateShuffleReaderOpMaker,
199234
paddle::framework::EmptyGradOpMaker,
200-
ops::CreateReaderInferVarType);
235+
ops::CreateDecoratedReaderInferVarType);
201236
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
202237
ops::CreateDecoratedReaderInferShape,
203238
ops::CreateBatchReaderOpMaker,
204239
paddle::framework::EmptyGradOpMaker,
205-
ops::CreateReaderInferVarType);
240+
ops::CreateDecoratedReaderInferVarType);

python/paddle/v2/fluid/tests/test_cpp_reader.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
random_reader = block.create_var(
2323
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
24-
random_reader.desc.set_lod_levels([0, 0])
24+
random_reader.desc.set_dtypes(
25+
[fluid.core.DataType.FP32, fluid.core.DataType.FP32])
2526

2627
create_random_data_generator_op = block.append_op(
2728
type="create_random_data_generator",
@@ -30,11 +31,11 @@
3031
"shape_concat": [1, 2, 1, 1],
3132
"ranks": [2, 2],
3233
"min": 0.0,
33-
"max": 1.0
34+
"max": 1.0,
35+
'lod_levels': [0, 0]
3436
})
3537
shuffle_reader = block.create_var(
3638
type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
37-
shuffle_reader.desc.set_lod_levels([0, 0])
3839

3940
create_shuffle_reader_op = block.append_op(
4041
type="create_shuffle_reader",
@@ -44,7 +45,6 @@
4445

4546
batch_reader = block.create_var(
4647
type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
47-
batch_reader.desc.set_lod_levels([1, 1])
4848

4949
create_batch_reader_op = block.append_op(
5050
type="create_batch_reader",
@@ -62,11 +62,9 @@
6262
place = fluid.CPUPlace()
6363
exe = fluid.Executor(place)
6464

65-
[res1, res2] = exe.run(prog, fetch_list=[out1, out2], return_numpy=False)
65+
[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
6666

67-
test_pass = res1.lod() == [range(0, 11)] and res1.lod() == [
68-
range(0, 11)
69-
] and np.array(res1).shape == (10, 2) and np.array(res2).shape == (10, 1)
67+
test_pass = res1.shape == (10, 2) and res2.shape == (10, 1)
7068

7169
if not test_pass:
7270
exit(1)

0 commit comments

Comments
 (0)