@@ -129,13 +129,15 @@ static bool has_feed_operators(
129
129
feed_count, feed_targets.size (),
130
130
" The number of feed operators should match 'feed_targets'" );
131
131
132
- // When feed operator are present, so should be feed_holder
133
- auto var = block.FindVar (feed_holder_name);
134
- PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
135
- feed_holder_name);
136
- PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FEED_MINIBATCH,
137
- " '%s' variable should be 'FEED_MINIBATCH' type" ,
138
- feed_holder_name);
132
+ if (!feed_holder_name.empty ()) {
133
+ // When feed operator are present, so should be feed_holder
134
+ auto var = block.FindVar (feed_holder_name);
135
+ PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
136
+ feed_holder_name);
137
+ PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FEED_MINIBATCH,
138
+ " '%s' variable should be 'FEED_MINIBATCH' type" ,
139
+ feed_holder_name);
140
+ }
139
141
}
140
142
141
143
return feed_count > 0 ;
@@ -169,13 +171,15 @@ static bool has_fetch_operators(
169
171
fetch_count, fetch_targets.size (),
170
172
" The number of fetch operators should match 'fetch_targets'" );
171
173
172
- // When fetch operator are present, so should be fetch_holder
173
- auto var = block.FindVar (fetch_holder_name);
174
- PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
175
- fetch_holder_name);
176
- PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FETCH_LIST,
177
- " '%s' variable should be 'FETCH_LIST' type" ,
178
- fetch_holder_name);
174
+ if (!fetch_holder_name.empty ()) {
175
+ // When fetch operator are present, so should be fetch_holder
176
+ auto var = block.FindVar (fetch_holder_name);
177
+ PADDLE_ENFORCE_NOT_NULL (var, " Block should already have a '%s' variable" ,
178
+ fetch_holder_name);
179
+ PADDLE_ENFORCE_EQ (var->GetType (), proto::VarType::FETCH_LIST,
180
+ " '%s' variable should be 'FETCH_LIST' type" ,
181
+ fetch_holder_name);
182
+ }
179
183
}
180
184
181
185
return fetch_count > 0 ;
@@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
222
226
}
223
227
}
224
228
225
- // map the data of feed_targets to feed_holder
226
- for (auto * op : global_block->AllOps ()) {
227
- if (op->Type () == kFeedOpType ) {
228
- std::string feed_target_name = op->Output (" Out" )[0 ];
229
- int idx = boost::get<int >(op->GetAttr (" col" ));
230
- SetFeedVariable (scope, *feed_targets[feed_target_name], feed_holder_name,
231
- idx);
232
- }
233
- }
234
-
235
229
if (!has_fetch_ops) {
236
230
// create fetch_holder variable
237
231
auto * fetch_holder = global_block->Var (fetch_holder_name);
@@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
255
249
}
256
250
}
257
251
258
- Run (*copy_program, scope, 0 , create_vars, create_vars);
259
-
260
- // obtain the data of fetch_targets from fetch_holder
261
- for (auto * op : global_block->AllOps ()) {
262
- if (op->Type () == kFetchOpType ) {
263
- std::string fetch_target_name = op->Input (" X" )[0 ];
264
- int idx = boost::get<int >(op->GetAttr (" col" ));
265
- *fetch_targets[fetch_target_name] =
266
- GetFetchVariable (*scope, fetch_holder_name, idx);
267
- }
268
- }
252
+ auto ctx = Prepare (*copy_program, 0 );
253
+ RunPreparedContext (ctx.get (), scope, feed_targets, fetch_targets,
254
+ feed_holder_name, fetch_holder_name, create_vars);
269
255
}
270
256
271
257
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare (
@@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
343
329
}
344
330
}
345
331
332
+ void Executor::RunPreparedContext (
333
+ ExecutorPrepareContext* ctx, Scope* scope,
334
+ std::map<std::string, const LoDTensor*>& feed_targets,
335
+ std::map<std::string, LoDTensor*>& fetch_targets,
336
+ const std::string& feed_holder_name, const std::string& fetch_holder_name,
337
+ bool create_vars) {
338
+ auto & global_block = ctx->prog_ .Block (ctx->block_id_ );
339
+
340
+ // map the data of feed_targets to feed_holder
341
+ for (auto * op : global_block.AllOps ()) {
342
+ if (op->Type () == kFeedOpType ) {
343
+ std::string feed_target_name = op->Output (" Out" )[0 ];
344
+ PADDLE_ENFORCE (feed_targets.find (feed_target_name) != feed_targets.end (),
345
+ " Variable %s is not feeded." );
346
+
347
+ int idx = boost::get<int >(op->GetAttr (" col" ));
348
+ SetFeedVariable (scope, *feed_targets[feed_target_name], feed_holder_name,
349
+ idx);
350
+ }
351
+ }
352
+
353
+ RunPreparedContext (ctx, scope, create_vars, create_vars);
354
+
355
+ // obtain the data of fetch_targets from fetch_holder
356
+ for (auto * op : global_block.AllOps ()) {
357
+ if (op->Type () == kFetchOpType ) {
358
+ std::string fetch_target_name = op->Input (" X" )[0 ];
359
+ PADDLE_ENFORCE (
360
+ fetch_targets.find (fetch_target_name) != fetch_targets.end (),
361
+ " Variable %s is not fetched." );
362
+
363
+ int idx = boost::get<int >(op->GetAttr (" col" ));
364
+ *fetch_targets[fetch_target_name] =
365
+ GetFetchVariable (*scope, fetch_holder_name, idx);
366
+ }
367
+ }
368
+ }
369
+
346
370
} // namespace framework
347
371
} // namespace paddle
0 commit comments