Skip to content

Commit dd1244f

Browse files
authored
Merge pull request #8943 from JiayiFeng/fix_bugs_in_readers
Fix a potential bug in the c++ reader
2 parents 6c06841 + 614c33f commit dd1244f

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

paddle/fluid/framework/reader.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,25 @@ class ReaderHolder {
6565

6666
ReaderBase* Get() const { return reader_.get(); }
6767

68-
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
69-
void ReInit() { reader_->ReInit(); }
68+
void ReadNext(std::vector<LoDTensor>* out) {
69+
PADDLE_ENFORCE_NOT_NULL(reader_);
70+
reader_->ReadNext(out);
71+
}
72+
void ReInit() {
73+
PADDLE_ENFORCE_NOT_NULL(reader_);
74+
reader_->ReInit();
75+
}
7076

71-
DDim shape(size_t idx) const { return reader_->shape(idx); }
72-
std::vector<DDim> shapes() const { return reader_->shapes(); }
77+
DDim shape(size_t idx) const {
78+
PADDLE_ENFORCE_NOT_NULL(reader_);
79+
return reader_->shape(idx);
80+
}
81+
std::vector<DDim> shapes() const {
82+
PADDLE_ENFORCE_NOT_NULL(reader_);
83+
return reader_->shapes();
84+
}
7385
void set_shapes(const std::vector<DDim>& shapes) {
86+
PADDLE_ENFORCE_NOT_NULL(reader_);
7487
reader_->set_shapes(shapes);
7588
}
7689

paddle/fluid/operators/reader/reader_op_registry.cc

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,25 @@ FileReaderMakerBase::FileReaderMakerBase(
4949
}
5050

5151
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
52+
PADDLE_ENFORCE(
53+
!ctx->IsRuntime(),
54+
"'FileReaderInferShape' should only be invoked during compile time.");
55+
5256
PADDLE_ENFORCE(ctx->HasOutput("Out"),
5357
"The output file reader should not be null.");
5458
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
5559
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
5660
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
5761
ctx->SetReaderDims("Out", shapes);
5862

59-
if (ctx->IsRuntime()) {
60-
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
61-
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
62-
"The number of 'lod_levels'(%d) doesn't match the number "
63-
"of 'shapes'(%d).",
64-
lod_levels.size(), shapes.size());
65-
framework::VarDesc* reader =
66-
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
67-
reader->SetLoDLevels(lod_levels);
68-
}
63+
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
64+
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
65+
"The number of 'lod_levels'(%d) doesn't match the number "
66+
"of 'shapes'(%d).",
67+
lod_levels.size(), shapes.size());
68+
framework::VarDesc* reader =
69+
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
70+
reader->SetLoDLevels(lod_levels);
6971
}
7072

7173
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
@@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
7779

7880
void DecoratedReaderInferShape::operator()(
7981
framework::InferShapeContext* ctx) const {
82+
PADDLE_ENFORCE(!ctx->IsRuntime(),
83+
"'DecoratedReaderInferShape' should only be invoked during "
84+
"compile time.");
85+
8086
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
8187
"Input(UnderlyingReader) should not be null.");
8288
PADDLE_ENFORCE(ctx->HasOutput("Out"),
8389
"The output decorated reader should not be null.");
8490
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
8591

86-
if (ctx->IsRuntime()) {
87-
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
88-
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
89-
framework::VarDesc* out_reader =
90-
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
91-
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
92-
}
92+
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
93+
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
94+
framework::VarDesc* out_reader =
95+
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
96+
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
9397
}
9498
void DecoratedReaderInferVarType::operator()(
9599
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {

0 commit comments

Comments
 (0)