Skip to content

Commit a9e826e

Browse files
committed
Add the check of has_feed/fetch_operators back.
1 parent 7b40f7c commit a9e826e

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,17 @@ void Executor::RunPreparedContext(
352352
bool create_vars) {
353353
auto& global_block = ctx->prog_.Block(ctx->block_id_);
354354

355+
PADDLE_ENFORCE(
356+
has_feed_operators(global_block, feed_targets, feed_holder_name),
357+
"Program in ExecutorPrepareContext should has feed_ops.");
358+
PADDLE_ENFORCE(
359+
has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
360+
"Program in the prepared context should has fetch_ops.");
361+
355362
// map the data of feed_targets to feed_holder
356363
for (auto* op : global_block.AllOps()) {
357364
if (op->Type() == kFeedOpType) {
358365
std::string feed_target_name = op->Output("Out")[0];
359-
PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(),
360-
"Variable %s is not feeded.");
361-
362366
int idx = boost::get<int>(op->GetAttr("col"));
363367
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
364368
idx);
@@ -371,10 +375,6 @@ void Executor::RunPreparedContext(
371375
for (auto* op : global_block.AllOps()) {
372376
if (op->Type() == kFetchOpType) {
373377
std::string fetch_target_name = op->Input("X")[0];
374-
PADDLE_ENFORCE(
375-
fetch_targets.find(fetch_target_name) != fetch_targets.end(),
376-
"Variable %s is not fetched.");
377-
378378
int idx = boost::get<int>(op->GetAttr("col"));
379379
*fetch_targets[fetch_target_name] =
380380
GetFetchVariable(*scope, fetch_holder_name, idx);

0 commit comments

Comments
 (0)