@@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase {
57
57
58
58
PADDLE_ENFORCE (platform::is_cpu_place (cond.place ()),
59
59
" Condition of while op must in CPU memory." );
60
+
61
+ auto ctx = executor.Prepare (*program, block->ID ());
60
62
while (cond.data <bool >()[0 ]) {
61
63
auto ¤t_scope = scope.NewScope ();
62
64
step_scopes->push_back (¤t_scope);
63
-
64
- executor.Run (*program, ¤t_scope, block->ID (),
65
- false /* create_local_scope*/ );
65
+ executor.RunPreparedContext (ctx.get (), ¤t_scope, false );
66
66
}
67
67
}
68
68
};
@@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase {
109
109
framework::Executor executor (dev_place);
110
110
auto *block = Attr<framework::BlockDesc *>(kStepBlock );
111
111
auto *program = block->Program ();
112
+ auto ctx = executor.Prepare (*program, block->ID ());
112
113
113
114
auto *step_scopes =
114
115
scope.FindVar (Input (kStepScopes ))->GetMutable <StepScopeVar>();
@@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase {
161
162
}
162
163
}
163
164
}
164
-
165
- executor.Run (*program, *cur_scope_iter, block->ID (), false );
165
+ executor.RunPreparedContext (ctx.get (), *cur_scope_iter, false );
166
166
167
167
auto &pg_names = Outputs (kXGRAD );
168
168
auto &p_names = Inputs (kX );
0 commit comments