Skip to content

Commit a6910f9

Browse files
authored
Always create variables in analysis_predictor before OptimizeInferenceProgram. (#15533)
Otherwise, some other persistable variable (like RAW type) will not be created
1 parent 748c2d3 commit a6910f9

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,22 +123,22 @@ bool AnalysisPredictor::PrepareProgram(
123123
if (!program) {
124124
if (!LoadProgramDesc()) return false;
125125

126+
// If not cloned, the parameters should be loaded.
127+
// If config_.ir_optim() is True, parameters is loaded in
128+
// OptimizeInferenceProgram(), but other persistable variables
129+
// (like RAW type var) are not created in scope.
130+
// If config_.ir_optim() is False, parameters is loaded in LoadParameters(),
131+
// still need to create other persistable variables.
132+
// So in both case, create persistable variables at first.
133+
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
134+
126135
// Optimize the program, and load parameters and modify them in the
127136
// scope_.
128137
// This will change the scope_ address.
129138
if (config_.ir_optim()) {
130139
status_ir_optim_enabled_ = true;
131140
OptimizeInferenceProgram();
132141
} else {
133-
// If the parent_scope is passed, we assert that the persistable variables
134-
// are already created, so just create the no persistable variables.
135-
136-
// If not cloned, the parameters should be loaded
137-
// OptimizeInferenceProgram.
138-
// So in both cases, just the local variables are needed to load, not the
139-
// parematers.
140-
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
141-
142142
// Load parameters
143143
LOG(INFO) << "load parameters ";
144144
LoadParameters();
@@ -376,7 +376,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
376376
}
377377
argument_.SetIrAnalysisPasses(passes);
378378
argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses());
379-
argument_.SetScopeNotOwned(const_cast<framework::Scope *>(scope_.get()));
379+
argument_.SetScopeNotOwned(scope_.get());
380380
Analyzer().Run(&argument_);
381381

382382
PADDLE_ENFORCE(argument_.scope_valid());

0 commit comments

Comments
 (0)