@@ -123,20 +123,41 @@ void ProgramInterpreter::RunImpl() {
 #endif
 }
 
-FetchList ProgramInterpreter::Run(
-    const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
+FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
+                                  bool need_fetch) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
 
 #ifdef PADDLE_WITH_DNNL
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  bool is_build = is_build_;
-  Prepare(feed_names, feed_tensors, is_build);
+  if (!is_build_) {
+    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
+    paddle::framework::interpreter::BuildVariableScope(
+        block_, execution_config_, &var_scope_);
 
-  if (is_build) {
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
+    paddle::framework::interpreter::BuildOpFuncList(
+        place_,
+        block_,
+        execution_config_.skip_gc_vars,
+        &op_func_nodes,
+        &var_scope_,
+        execution_config_,
+        HasLocalScope(),
+        static_build_);
+    SetFeedVarsInplaceSkip(feed_names);
+    // convert vec func_list to graph
+    Convert(&op_func_nodes);
+    UpdateSyncOpNum();
+    if (static_build_) {
+      VLOG(4) << "RUN impl";
+      RunImpl();
+    }
+    is_build_ = true;
+    is_shared_results_build_ = true;
+  } else {
     RunImpl();
   }
@@ -145,8 +166,10 @@ FetchList ProgramInterpreter::Run(
   }
 
   // return Fetch Tensors
-  auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
-  if (fetch_var) {
+  Scope* inner_scope =
+      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
+  if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
@@ -162,41 +185,20 @@ FetchList ProgramInterpreter::Run(
   }
 }
 
-FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
-                                  bool need_fetch) {
+FetchList ProgramInterpreter::Run(
+    const std::vector<std::string>& feed_names,
+    const std::vector<phi::DenseTensor>& feed_tensors) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
 
 #ifdef PADDLE_WITH_DNNL
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  if (!is_build_) {
-    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    paddle::framework::interpreter::BuildVariableScope(
-        block_, execution_config_, &var_scope_);
+  bool is_build = is_build_;
+  Prepare(feed_names, feed_tensors, is_build);
 
-    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
-    paddle::framework::interpreter::BuildOpFuncList(
-        place_,
-        block_,
-        execution_config_.skip_gc_vars,
-        &op_func_nodes,
-        &var_scope_,
-        execution_config_,
-        HasLocalScope(),
-        static_build_);
-    SetFeedVarsInplaceSkip(feed_names);
-    // convert vec func_list to graph
-    Convert(&op_func_nodes);
-    UpdateSyncOpNum();
-    if (static_build_) {
-      VLOG(4) << "RUN impl";
-      RunImpl();
-    }
-    is_build_ = true;
-    is_shared_results_build_ = true;
-  } else {
+  if (is_build) {
     RunImpl();
   }
@@ -208,7 +210,7 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   Scope* inner_scope =
       HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var && need_fetch) {
+  if (fetch_var) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
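For context, a minimal caller-side sketch of how the two Run overloads differ after this change. This is illustrative only and not part of the diff: the ProgramInterpreter instance, the wrapper function RunOnce, and the header path are assumptions made for the sketch.

// Illustrative sketch, not from the patch. Assumes an already-constructed
// ProgramInterpreter `interp`, feed tensors prepared by the caller, and that
// the header path below is correct for this Paddle revision.
#include <string>
#include <vector>

#include "paddle/fluid/framework/new_executor/program_interpreter.h"

using paddle::framework::FetchList;
using paddle::framework::ProgramInterpreter;

FetchList RunOnce(ProgramInterpreter& interp,
                  const std::vector<std::string>& feed_names,
                  const std::vector<phi::DenseTensor>& feed_tensors) {
  // Overload taking feed tensors: Prepare() feeds the named variables,
  // then RunImpl() executes once the program has been built.
  FetchList with_feeds = interp.Run(feed_names, feed_tensors);

  // Overload taking need_fetch: feeds are expected to already be in the
  // scope; it builds the op-func list on the first call, and with
  // need_fetch=false it skips returning the fetch list.
  FetchList no_fetch = interp.Run(feed_names, /*need_fetch=*/false);

  return with_feeds;
}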