Skip to content

Commit 37a272e

Browse files
authored
add executor.prepare (#9022)
optimize executor.run
1 parent 30b7032 commit 37a272e

File tree

4 files changed

+116
-93
lines changed

4 files changed

+116
-93
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 11 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -14,12 +14,8 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/framework/executor.h"
1616

17-
#include <set>
18-
19-
#include "gflags/gflags.h"
2017
#include "paddle/fluid/framework/channel.h"
2118
#include "paddle/fluid/framework/feed_fetch_method.h"
22-
#include "paddle/fluid/framework/feed_fetch_type.h"
2319
#include "paddle/fluid/framework/lod_rank_table.h"
2420
#include "paddle/fluid/framework/lod_tensor_array.h"
2521
#include "paddle/fluid/framework/op_registry.h"
@@ -40,14 +36,13 @@ namespace {
4036
int kProgramId = -1;
4137
} // namespace
4238

43-
struct ExecutorPrepareContext {
44-
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
45-
: prog_(prog), block_id_(block_id) {}
39+
ExecutorPrepareContext::ExecutorPrepareContext(
40+
const framework::ProgramDesc& prog, size_t block_id)
41+
: prog_(prog), block_id_(block_id) {}
4642

47-
const framework::ProgramDesc& prog_;
48-
size_t block_id_;
49-
std::vector<std::unique_ptr<OperatorBase>> ops_;
50-
};
43+
ExecutorPrepareContext::~ExecutorPrepareContext() {
44+
VLOG(5) << "destroy ExecutorPrepareContext";
45+
}
5146

5247
Executor::Executor(const platform::Place& place) : place_(place) {}
5348

@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name,
10196
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
10297
bool create_local_scope, bool create_vars) {
10398
platform::RecordBlock b(block_id);
104-
auto* ctx = Prepare(pdesc, block_id);
105-
RunPreparedContext(ctx, scope, create_local_scope, create_vars);
106-
delete ctx;
99+
auto ctx = Prepare(pdesc, block_id);
100+
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
107101
}
108102

109103
// Check whether the block already has feed operators and feed_holder.
@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
274268
}
275269
}
276270

277-
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
278-
int block_id) {
271+
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
272+
const ProgramDesc& program, int block_id) {
279273
auto* ctx = new ExecutorPrepareContext(program, block_id);
280274
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
281275
auto& block = program.Block(block_id);
282276
for (auto& op_desc : block.AllOps()) {
283277
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
284278
}
285-
return ctx;
279+
return std::unique_ptr<ExecutorPrepareContext>(ctx);
286280
}
287281

288282
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,

paddle/fluid/framework/executor.h

Lines changed: 12 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -22,7 +22,16 @@ limitations under the License. */
2222

2323
namespace paddle {
2424
namespace framework {
25-
struct ExecutorPrepareContext;
25+
26+
struct ExecutorPrepareContext {
27+
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
28+
~ExecutorPrepareContext();
29+
30+
const framework::ProgramDesc& prog_;
31+
size_t block_id_;
32+
std::vector<std::unique_ptr<OperatorBase>> ops_;
33+
};
34+
2635
class Executor {
2736
public:
2837
// TODO(dzhwinter) : Do not rely on this function, it will be removed
@@ -47,8 +56,8 @@ class Executor {
4756
const std::string& feed_holder_name = "feed",
4857
const std::string& fetch_holder_name = "fetch");
4958

50-
static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
51-
int block_id);
59+
static std::unique_ptr<ExecutorPrepareContext> Prepare(
60+
const ProgramDesc& program, int block_id);
5261

5362
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
5463
bool create_local_scope = true,

python/paddle/fluid/executor.py

Lines changed: 93 additions & 72 deletions
Original file line number · Diff line number · Diff line change
@@ -235,6 +235,77 @@ def parselod(data):
235235
tensor.set_lod(lod)
236236
return tensor
237237

238+
def _get_program_cache(self, program_cache_key):
239+
return self.program_caches.get(program_cache_key, None)
240+
241+
def _add_program_cache(self, program_cache_key, program):
242+
self.program_caches[program_cache_key] = program
243+
244+
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
245+
fetch_var_name):
246+
tmp_program = program.clone()
247+
248+
global_block = tmp_program.global_block()
249+
250+
if feed_var_name in global_block.vars:
251+
feed_var = global_block.var(feed_var_name)
252+
else:
253+
feed_var = global_block.create_var(
254+
name=feed_var_name,
255+
type=core.VarDesc.VarType.FEED_MINIBATCH,
256+
persistable=True)
257+
258+
if fetch_var_name in global_block.vars:
259+
fetch_var = global_block.var(fetch_var_name)
260+
else:
261+
fetch_var = global_block.create_var(
262+
name=fetch_var_name,
263+
type=core.VarDesc.VarType.FETCH_LIST,
264+
persistable=True)
265+
266+
# prepend feed operators
267+
if not has_feed_operators(global_block, feed, feed_var_name):
268+
for i, name in enumerate(feed):
269+
out = global_block.var(name)
270+
global_block.prepend_op(
271+
type='feed',
272+
inputs={'X': [feed_var]},
273+
outputs={'Out': [out]},
274+
attrs={'col': i})
275+
276+
# append fetch_operators
277+
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
278+
for i, var in enumerate(fetch_list):
279+
assert isinstance(var, Variable) or isinstance(var, str), (
280+
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
281+
global_block.append_op(
282+
type='fetch',
283+
inputs={'X': [var]},
284+
outputs={'Out': [fetch_var]},
285+
attrs={'col': i})
286+
287+
return tmp_program
288+
289+
def _feed_data(self, program, feed, feed_var_name, scope):
290+
# feed var to framework
291+
for op in program.global_block().ops:
292+
if op.desc.type() == 'feed':
293+
feed_target_name = op.desc.output('Out')[0]
294+
cur_feed = feed[feed_target_name]
295+
if not isinstance(cur_feed, core.LoDTensor):
296+
cur_feed = self.aslodtensor(cur_feed)
297+
idx = op.desc.attr('col')
298+
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
299+
else:
300+
break
301+
302+
def _fetch_data(self, fetch_list, fetch_var_name, scope):
303+
outs = [
304+
core.get_fetch_variable(scope, fetch_var_name, i)
305+
for i in xrange(len(fetch_list))
306+
]
307+
return outs
308+
238309
def run(self,
239310
program=None,
240311
feed=None,
@@ -268,7 +339,6 @@ def run(self,
268339
raise TypeError("feed should be a map")
269340
if fetch_list is None:
270341
fetch_list = []
271-
272342
if program is None:
273343
program = default_main_program()
274344

@@ -278,79 +348,30 @@ def run(self,
278348
if scope is None:
279349
scope = global_scope()
280350

281-
program_cache = None
282-
program_cache_key = get_program_cache_key(feed, fetch_list)
283-
351+
cache_key = get_program_cache_key(feed, fetch_list)
284352
if use_program_cache:
285-
# find program cache by cache_key
286-
program_cache = self.program_caches.get(program_cache_key, None)
287-
# TODO(qiao): Should check program_cache and program are exactly the same.
353+
cached_program = self._get_program_cache(cache_key)
354+
if cached_program is None:
355+
cached_program = self._add_feed_fetch_ops(
356+
program=program,
357+
feed=feed,
358+
fetch_list=fetch_list,
359+
feed_var_name=feed_var_name,
360+
fetch_var_name=fetch_var_name)
361+
self._add_program_cache(cache_key, cached_program)
362+
program = cached_program
288363
else:
289-
self.program_caches.pop(program_cache_key, None)
290-
291-
if program_cache is None:
292-
program_cache = program.clone()
293-
294-
if use_program_cache:
295-
self.program_caches[program_cache_key] = program_cache
296-
297-
global_block = program_cache.global_block()
298-
299-
if feed_var_name in global_block.vars:
300-
feed_var = global_block.var(feed_var_name)
301-
else:
302-
feed_var = global_block.create_var(
303-
name=feed_var_name,
304-
type=core.VarDesc.VarType.FEED_MINIBATCH,
305-
persistable=True)
306-
307-
if fetch_var_name in global_block.vars:
308-
fetch_var = global_block.var(fetch_var_name)
309-
else:
310-
fetch_var = global_block.create_var(
311-
name=fetch_var_name,
312-
type=core.VarDesc.VarType.FETCH_LIST,
313-
persistable=True)
314-
315-
# prepend feed operators
316-
if not has_feed_operators(global_block, feed, feed_var_name):
317-
for i, name in enumerate(feed):
318-
out = global_block.var(name)
319-
global_block.prepend_op(
320-
type='feed',
321-
inputs={'X': [feed_var]},
322-
outputs={'Out': [out]},
323-
attrs={'col': i})
324-
325-
# append fetch_operators
326-
if not has_fetch_operators(global_block, fetch_list,
327-
fetch_var_name):
328-
for i, var in enumerate(fetch_list):
329-
assert isinstance(var, Variable) or isinstance(var, str), (
330-
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
331-
global_block.append_op(
332-
type='fetch',
333-
inputs={'X': [var]},
334-
outputs={'Out': [fetch_var]},
335-
attrs={'col': i})
336-
337-
# feed var to framework
338-
for op in program_cache.global_block().ops:
339-
if op.desc.type() == 'feed':
340-
feed_target_name = op.desc.output('Out')[0]
341-
cur_feed = feed[feed_target_name]
342-
if not isinstance(cur_feed, core.LoDTensor):
343-
cur_feed = self.aslodtensor(cur_feed)
344-
idx = op.desc.attr('col')
345-
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
346-
else:
347-
break
348-
349-
self.executor.run(program_cache.desc, scope, 0, True, True)
350-
outs = [
351-
core.get_fetch_variable(scope, fetch_var_name, i)
352-
for i in xrange(len(fetch_list))
353-
]
364+
self.program_caches.pop(cache_key, None)
365+
program = self._add_feed_fetch_ops(
366+
program=program,
367+
feed=feed,
368+
fetch_list=fetch_list,
369+
feed_var_name=feed_var_name,
370+
fetch_var_name=fetch_var_name)
371+
372+
self._feed_data(program, feed, feed_var_name, scope)
373+
self.executor.run(program.desc, scope, 0, True, True)
374+
outs = self._fetch_data(fetch_list, fetch_var_name, scope)
354375
if return_numpy:
355376
outs = as_numpy(outs)
356377
return outs

python/paddle/fluid/tests/unittests/test_executor_and_mul.py

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -16,7 +16,6 @@
1616

1717
import numpy
1818
import paddle.fluid.core as core
19-
2019
from paddle.fluid.executor import Executor
2120
from paddle.fluid.layers import mul, data
2221

0 commit comments

Comments (0)