@@ -49,23 +49,25 @@ FileReaderMakerBase::FileReaderMakerBase(
49
49
}
50
50
51
51
void FileReaderInferShape::operator ()(framework::InferShapeContext* ctx) const {
52
+ PADDLE_ENFORCE (
53
+ !ctx->IsRuntime (),
54
+ " 'FileReaderInferShape' should only be invoked during compile time." );
55
+
52
56
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
53
57
" The output file reader should not be null." );
54
58
const auto shape_concat = ctx->Attrs ().Get <std::vector<int >>(" shape_concat" );
55
59
const auto ranks = ctx->Attrs ().Get <std::vector<int >>(" ranks" );
56
60
std::vector<framework::DDim> shapes = RestoreShapes (shape_concat, ranks);
57
61
ctx->SetReaderDims (" Out" , shapes);
58
62
59
- if (ctx->IsRuntime ()) {
60
- const auto lod_levels = ctx->Attrs ().Get <std::vector<int >>(" lod_levels" );
61
- PADDLE_ENFORCE_EQ (lod_levels.size (), shapes.size (),
62
- " The number of 'lod_levels'(%d) doesn't match the number "
63
- " of 'shapes'(%d)." ,
64
- lod_levels.size (), shapes.size ());
65
- framework::VarDesc* reader =
66
- boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs (" Out" )[0 ]);
67
- reader->SetLoDLevels (lod_levels);
68
- }
63
+ const auto lod_levels = ctx->Attrs ().Get <std::vector<int >>(" lod_levels" );
64
+ PADDLE_ENFORCE_EQ (lod_levels.size (), shapes.size (),
65
+ " The number of 'lod_levels'(%d) doesn't match the number "
66
+ " of 'shapes'(%d)." ,
67
+ lod_levels.size (), shapes.size ());
68
+ framework::VarDesc* reader =
69
+ boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs (" Out" )[0 ]);
70
+ reader->SetLoDLevels (lod_levels);
69
71
}
70
72
71
73
void FileReaderInferVarType::operator ()(const framework::OpDesc& op_desc,
@@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
77
79
78
80
void DecoratedReaderInferShape::operator ()(
79
81
framework::InferShapeContext* ctx) const {
82
+ PADDLE_ENFORCE (!ctx->IsRuntime (),
83
+ " 'DecoratedReaderInferShape' should only be invoked during "
84
+ " compile time." );
85
+
80
86
PADDLE_ENFORCE (ctx->HasInput (" UnderlyingReader" ),
81
87
" Input(UnderlyingReader) should not be null." );
82
88
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
83
89
" The output decorated reader should not be null." );
84
90
ctx->SetReaderDims (" Out" , ctx->GetReaderDims (" UnderlyingReader" ));
85
91
86
- if (ctx->IsRuntime ()) {
87
- framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
88
- ctx->GetInputVarPtrs (" UnderlyingReader" )[0 ]);
89
- framework::VarDesc* out_reader =
90
- boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs (" Out" )[0 ]);
91
- out_reader->SetLoDLevels (in_reader->GetLoDLevels ());
92
- }
92
+ framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
93
+ ctx->GetInputVarPtrs (" UnderlyingReader" )[0 ]);
94
+ framework::VarDesc* out_reader =
95
+ boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs (" Out" )[0 ]);
96
+ out_reader->SetLoDLevels (in_reader->GetLoDLevels ());
93
97
}
94
98
void DecoratedReaderInferVarType::operator ()(
95
99
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
0 commit comments