Skip to content

Commit 78415f3

Browse files
authored
Merge pull request #12838 from panyx0718/infer
speed up while_op
2 parents bc4f537 + a2c0e52 commit 78415f3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

paddle/fluid/operators/while_op.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase {
5757

5858
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
5959
"Condition of while op must in CPU memory.");
60+
61+
auto ctx = executor.Prepare(*program, block->ID());
6062
while (cond.data<bool>()[0]) {
6163
auto &current_scope = scope.NewScope();
6264
step_scopes->push_back(&current_scope);
63-
64-
executor.Run(*program, &current_scope, block->ID(),
65-
false /*create_local_scope*/);
65+
executor.RunPreparedContext(ctx.get(), &current_scope, false);
6666
}
6767
}
6868
};
@@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase {
109109
framework::Executor executor(dev_place);
110110
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
111111
auto *program = block->Program();
112+
auto ctx = executor.Prepare(*program, block->ID());
112113

113114
auto *step_scopes =
114115
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
@@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase {
161162
}
162163
}
163164
}
164-
165-
executor.Run(*program, *cur_scope_iter, block->ID(), false);
165+
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
166166

167167
auto &pg_names = Outputs(kXGRAD);
168168
auto &p_names = Inputs(kX);

0 commit comments

Comments
 (0)