Commit 43d09a1

Extract Prepare from Executor
1 parent f7e9fe5 commit 43d09a1

File tree: paddle/fluid/framework/executor.cc, paddle/fluid/framework/executor.h

2 files changed: +98 -70 lines

paddle/fluid/framework/executor.cc

Lines changed: 88 additions & 67 deletions
@@ -34,6 +34,15 @@ DEFINE_bool(check_nan_inf, false,
 namespace paddle {
 namespace framework {

+struct ExecutorPrepareContext {
+  ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
+      : prog_(prog), block_id_(block_id) {}
+
+  framework::ProgramDesc prog_;
+  size_t block_id_;
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
+};
+
 Executor::Executor(const platform::Place& place) : place_(place) {}

 static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
@@ -85,73 +94,9 @@ static void CheckTensorNANOrInf(const std::string& name,

 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope, bool create_vars) {
-  // TODO(tonyyang-svail):
-  //   - only runs on the first device (i.e. no interdevice communication)
-  //   - will change to use multiple blocks for RNN op and Cond Op
-  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), pdesc.Size());
-  auto& block = pdesc.Block(block_id);
-
-  Scope* local_scope = scope;
-  if (create_vars) {
-    if (create_local_scope) {
-      local_scope = &scope->NewScope();
-      for (auto& var : block.AllVars()) {
-        if (var->Name() == framework::kEmptyVarName) {
-          continue;
-        }
-
-        if (var->Persistable()) {
-          auto* ptr = scope->Var(var->Name());
-          CreateTensor(ptr, var->GetType());
-          VLOG(3) << "Create Variable " << var->Name()
-                  << " global, which pointer is " << ptr;
-        } else {
-          auto* ptr = local_scope->Var(var->Name());
-          CreateTensor(ptr, var->GetType());
-          VLOG(3) << "Create Variable " << var->Name()
-                  << " locally, which pointer is " << ptr;
-        }
-      }
-    } else {
-      for (auto& var : block.AllVars()) {
-        auto* ptr = local_scope->Var(var->Name());
-        CreateTensor(ptr, var->GetType());
-        VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
-                << ptr;
-      }
-    }  // if (create_local_scope)
-  }    // if (create_vars)
-
-  for (auto& op_desc : block.AllOps()) {
-    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-
-    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
-    op->Run(*local_scope, place_);
-    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
-
-    if (FLAGS_benchmark) {
-      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
-              << memory::memory_usage(place_);
-    }
-    if (FLAGS_check_nan_inf) {
-      for (auto& vname : op->OutputVars(true)) {
-        auto* var = local_scope->FindVar(vname);
-        if (var == nullptr) continue;
-        if (var->IsType<framework::LoDTensor>()) {
-          CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
-        }
-      }
-    }
-  }
-  if (create_vars && create_local_scope) {
-    scope->DeleteScope(local_scope);
-  }
-  if (FLAGS_benchmark) {
-    VLOG(2) << "-------------------------------------------------------";
-    VLOG(2) << "Memory used after deleting local scope: "
-            << memory::memory_usage(place_);
-    VLOG(2) << "-------------------------------------------------------";
-  }
+  auto* ctx = Prepare(pdesc, block_id);
+  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
+  delete ctx;
 }

 // Check whether the block already has feed operators and feed_holder.
@@ -313,5 +258,81 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
   delete copy_program;
 }

+ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
+                                          int block_id) {
+  auto* ctx = new ExecutorPrepareContext(program, block_id);
+  PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
+  auto& block = program.Block(block_id);
+  for (auto& op_desc : block.AllOps()) {
+    ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
+  }
+  return ctx;
+}
+
+void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                                  bool create_local_scope, bool create_vars) {
+  auto& block = ctx->prog_.Block(ctx->block_id_);
+
+  Scope* local_scope = scope;
+  if (create_vars) {
+    if (create_local_scope) {
+      local_scope = &scope->NewScope();
+      for (auto& var : block.AllVars()) {
+        if (var->Name() == framework::kEmptyVarName) {
+          continue;
+        }
+
+        if (var->Persistable()) {
+          auto* ptr = scope->Var(var->Name());
+          CreateTensor(ptr, var->GetType());
+          VLOG(3) << "Create Variable " << var->Name()
+                  << " global, which pointer is " << ptr;
+        } else {
+          auto* ptr = local_scope->Var(var->Name());
+          CreateTensor(ptr, var->GetType());
+          VLOG(3) << "Create Variable " << var->Name()
+                  << " locally, which pointer is " << ptr;
+        }
+      }
+    } else {
+      for (auto& var : block.AllVars()) {
+        auto* ptr = local_scope->Var(var->Name());
+        CreateTensor(ptr, var->GetType());
+        VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
+                << ptr;
+      }
+    }  // if (create_local_scope)
+  }    // if (create_vars)
+
+  for (auto& op : ctx->ops_) {
+    VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
+    op->Run(*local_scope, place_);
+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
+
+    if (FLAGS_benchmark) {
+      VLOG(2) << "Memory used after operator " + op->Type() + " running: "
+              << memory::memory_usage(place_);
+    }
+    if (FLAGS_check_nan_inf) {
+      for (auto& vname : op->OutputVars(true)) {
+        auto* var = local_scope->FindVar(vname);
+        if (var == nullptr) continue;
+        if (var->IsType<framework::LoDTensor>()) {
+          CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
+        }
+      }
+    }
+  }
+  if (create_vars && create_local_scope) {
+    scope->DeleteScope(local_scope);
+  }
+  if (FLAGS_benchmark) {
+    VLOG(2) << "-------------------------------------------------------";
+    VLOG(2) << "Memory used after deleting local scope: "
+            << memory::memory_usage(place_);
+    VLOG(2) << "-------------------------------------------------------";
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/executor.h

Lines changed: 10 additions & 3 deletions
@@ -22,7 +22,7 @@ limitations under the License. */

 namespace paddle {
 namespace framework {
-
+struct ExecutorPrepareContext;
 class Executor {
  public:
   // TODO(dzhwinter) : Do not rely on this function, it will be removed
@@ -38,15 +38,22 @@ class Executor {
   * ProgramDesc
   * Scope
   */
-  void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
-           bool create_vars = true);
+  void Run(const ProgramDesc& prog, Scope* scope, int block_id,
+           bool create_local_scope = true, bool create_vars = true);

   void Run(const ProgramDesc& program, Scope* scope,
            std::map<std::string, const LoDTensor*>& feed_targets,
            std::map<std::string, LoDTensor*>& fetch_targets,
            const std::string& feed_holder_name = "feed",
            const std::string& fetch_holder_name = "fetch");

+  static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
+                                         int block_id);
+
+  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                          bool create_local_scope = true,
+                          bool create_vars = true);
+
  private:
   const platform::Place place_;
 };
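The header now exposes the prepare/run split introduced by this commit: Prepare builds the operator list for one block once, and RunPreparedContext replays those cached operators against a scope, while Run simply delegates to the pair. Below is a minimal usage sketch, not part of the commit: it assumes the full ExecutorPrepareContext definition is visible to the caller (in this commit it still lives only in executor.cc, so only Executor::Run exercises the pair), and the helper name RunRepeatedly and block id 0 are illustrative only.

// Hypothetical sketch of reusing a prepared context across runs.
// Assumes ExecutorPrepareContext's definition is visible so that
// `delete ctx` is well-formed; in this commit that holds only
// inside executor.cc.
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

void RunRepeatedly(const paddle::framework::ProgramDesc& program,
                   paddle::framework::Scope* scope,
                   const paddle::platform::Place& place, int steps) {
  paddle::framework::Executor executor(place);
  // Operator creation happens once, in Prepare().
  auto* ctx = paddle::framework::Executor::Prepare(program, /*block_id=*/0);
  for (int i = 0; i < steps; ++i) {
    // Each call only creates variables and runs the cached operators.
    executor.RunPreparedContext(ctx, scope);
  }
  // The caller owns the context, mirroring what Run() now does internally.
  delete ctx;
}

The design choice mirrored here is that preparation (parsing the block and instantiating operators via OpRegistry::CreateOp) is paid once, while per-step work is limited to variable creation and operator execution.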
