@@ -34,6 +34,15 @@ DEFINE_bool(check_nan_inf, false,
34
34
namespace paddle {
35
35
namespace framework {
36
36
37
+ struct ExecutorPrepareContext {
38
+ ExecutorPrepareContext (const framework::ProgramDesc& prog, size_t block_id)
39
+ : prog_(prog), block_id_(block_id) {}
40
+
41
+ framework::ProgramDesc prog_;
42
+ size_t block_id_;
43
+ std::vector<std::unique_ptr<OperatorBase>> ops_;
44
+ };
45
+
37
46
// Constructs an executor bound to `place`; the value is stored in place_ and
// later passed to every operator this executor runs.
Executor::Executor(const platform::Place& place)
    : place_(place) {}
38
47
39
48
static void CreateTensor (Variable* var, proto::VarType::Type var_type) {
@@ -85,73 +94,9 @@ static void CheckTensorNANOrInf(const std::string& name,
85
94
86
95
void Executor::Run (const ProgramDesc& pdesc, Scope* scope, int block_id,
87
96
bool create_local_scope, bool create_vars) {
88
- // TODO(tonyyang-svail):
89
- // - only runs on the first device (i.e. no interdevice communication)
90
- // - will change to use multiple blocks for RNN op and Cond Op
91
- PADDLE_ENFORCE_LT (static_cast <size_t >(block_id), pdesc.Size ());
92
- auto & block = pdesc.Block (block_id);
93
-
94
- Scope* local_scope = scope;
95
- if (create_vars) {
96
- if (create_local_scope) {
97
- local_scope = &scope->NewScope ();
98
- for (auto & var : block.AllVars ()) {
99
- if (var->Name () == framework::kEmptyVarName ) {
100
- continue ;
101
- }
102
-
103
- if (var->Persistable ()) {
104
- auto * ptr = scope->Var (var->Name ());
105
- CreateTensor (ptr, var->GetType ());
106
- VLOG (3 ) << " Create Variable " << var->Name ()
107
- << " global, which pointer is " << ptr;
108
- } else {
109
- auto * ptr = local_scope->Var (var->Name ());
110
- CreateTensor (ptr, var->GetType ());
111
- VLOG (3 ) << " Create Variable " << var->Name ()
112
- << " locally, which pointer is " << ptr;
113
- }
114
- }
115
- } else {
116
- for (auto & var : block.AllVars ()) {
117
- auto * ptr = local_scope->Var (var->Name ());
118
- CreateTensor (ptr, var->GetType ());
119
- VLOG (3 ) << " Create variable " << var->Name () << " , which pointer is "
120
- << ptr;
121
- }
122
- } // if (create_local_scope)
123
- } // if (create_vars)
124
-
125
- for (auto & op_desc : block.AllOps ()) {
126
- auto op = paddle::framework::OpRegistry::CreateOp (*op_desc);
127
-
128
- VLOG (4 ) << place_ << " " << op->DebugStringEx (local_scope);
129
- op->Run (*local_scope, place_);
130
- VLOG (3 ) << place_ << " " << op->DebugStringEx (local_scope);
131
-
132
- if (FLAGS_benchmark) {
133
- VLOG (2 ) << " Memory used after operator " + op->Type () + " running: "
134
- << memory::memory_usage (place_);
135
- }
136
- if (FLAGS_check_nan_inf) {
137
- for (auto & vname : op->OutputVars (true )) {
138
- auto * var = local_scope->FindVar (vname);
139
- if (var == nullptr ) continue ;
140
- if (var->IsType <framework::LoDTensor>()) {
141
- CheckTensorNANOrInf (vname, var->Get <framework::LoDTensor>());
142
- }
143
- }
144
- }
145
- }
146
- if (create_vars && create_local_scope) {
147
- scope->DeleteScope (local_scope);
148
- }
149
- if (FLAGS_benchmark) {
150
- VLOG (2 ) << " -------------------------------------------------------" ;
151
- VLOG (2 ) << " Memory used after deleting local scope: "
152
- << memory::memory_usage (place_);
153
- VLOG (2 ) << " -------------------------------------------------------" ;
154
- }
97
+ auto * ctx = Prepare (pdesc, block_id);
98
+ RunPreparedContext (ctx, scope, create_local_scope, create_vars);
99
+ delete ctx;
155
100
}
156
101
157
102
// Check whether the block already has feed operators and feed_holder.
@@ -313,5 +258,81 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
313
258
delete copy_program;
314
259
}
315
260
261
+ ExecutorPrepareContext* Executor::Prepare (const ProgramDesc& program,
262
+ int block_id) {
263
+ auto * ctx = new ExecutorPrepareContext (program, block_id);
264
+ PADDLE_ENFORCE_LT (static_cast <size_t >(block_id), program.Size ());
265
+ auto & block = program.Block (block_id);
266
+ for (auto & op_desc : block.AllOps ()) {
267
+ ctx->ops_ .push_back (OpRegistry::CreateOp (*op_desc));
268
+ }
269
+ return ctx;
270
+ }
271
+
272
+ void Executor::RunPreparedContext (ExecutorPrepareContext* ctx, Scope* scope,
273
+ bool create_local_scope, bool create_vars) {
274
+ auto & block = ctx->prog_ .Block (ctx->block_id_ );
275
+
276
+ Scope* local_scope = scope;
277
+ if (create_vars) {
278
+ if (create_local_scope) {
279
+ local_scope = &scope->NewScope ();
280
+ for (auto & var : block.AllVars ()) {
281
+ if (var->Name () == framework::kEmptyVarName ) {
282
+ continue ;
283
+ }
284
+
285
+ if (var->Persistable ()) {
286
+ auto * ptr = scope->Var (var->Name ());
287
+ CreateTensor (ptr, var->GetType ());
288
+ VLOG (3 ) << " Create Variable " << var->Name ()
289
+ << " global, which pointer is " << ptr;
290
+ } else {
291
+ auto * ptr = local_scope->Var (var->Name ());
292
+ CreateTensor (ptr, var->GetType ());
293
+ VLOG (3 ) << " Create Variable " << var->Name ()
294
+ << " locally, which pointer is " << ptr;
295
+ }
296
+ }
297
+ } else {
298
+ for (auto & var : block.AllVars ()) {
299
+ auto * ptr = local_scope->Var (var->Name ());
300
+ CreateTensor (ptr, var->GetType ());
301
+ VLOG (3 ) << " Create variable " << var->Name () << " , which pointer is "
302
+ << ptr;
303
+ }
304
+ } // if (create_local_scope)
305
+ } // if (create_vars)
306
+
307
+ for (auto & op : ctx->ops_ ) {
308
+ VLOG (4 ) << place_ << " " << op->DebugStringEx (local_scope);
309
+ op->Run (*local_scope, place_);
310
+ VLOG (3 ) << place_ << " " << op->DebugStringEx (local_scope);
311
+
312
+ if (FLAGS_benchmark) {
313
+ VLOG (2 ) << " Memory used after operator " + op->Type () + " running: "
314
+ << memory::memory_usage (place_);
315
+ }
316
+ if (FLAGS_check_nan_inf) {
317
+ for (auto & vname : op->OutputVars (true )) {
318
+ auto * var = local_scope->FindVar (vname);
319
+ if (var == nullptr ) continue ;
320
+ if (var->IsType <framework::LoDTensor>()) {
321
+ CheckTensorNANOrInf (vname, var->Get <framework::LoDTensor>());
322
+ }
323
+ }
324
+ }
325
+ }
326
+ if (create_vars && create_local_scope) {
327
+ scope->DeleteScope (local_scope);
328
+ }
329
+ if (FLAGS_benchmark) {
330
+ VLOG (2 ) << " -------------------------------------------------------" ;
331
+ VLOG (2 ) << " Memory used after deleting local scope: "
332
+ << memory::memory_usage (place_);
333
+ VLOG (2 ) << " -------------------------------------------------------" ;
334
+ }
335
+ }
336
+
316
337
} // namespace framework
317
338
} // namespace paddle
0 commit comments