Skip to content

Commit 4977d99

Browse files
committed
add program cache for executor
1 parent 0165421 commit 4977d99

File tree

1 file changed

+55
-47
lines changed

1 file changed

+55
-47
lines changed

python/paddle/fluid/executor.py

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def __init__(self, places):
177177
# TODO(dzhwinter) : only use the first place
178178
self.executor = core.Executor(act_places[0])
179179
self.places = places
180+
self.program_caches = dict()
180181

181182
def aslodtensor(self, data):
182183
def accumulate(data):
@@ -240,56 +241,63 @@ def run(self,
240241
if scope is None:
241242
scope = global_scope()
242243

243-
program = program.clone()
244-
global_block = program.global_block()
244+
program_cache_key = str(feed.keys() + fetch_list)
245+
program_cache = self.program_caches.get(program_cache_key, None)
245246

246-
if feed_var_name in global_block.vars:
247-
feed_var = global_block.var(feed_var_name)
248-
else:
249-
feed_var = global_block.create_var(
250-
name=feed_var_name,
251-
type=core.VarDesc.VarType.FEED_MINIBATCH,
252-
persistable=True)
247+
if program_cache is None:
248+
program_cache = program.clone()
249+
self.program_caches[program_cache_key] = program_cache
253250

254-
if fetch_var_name in global_block.vars:
255-
fetch_var = global_block.var(fetch_var_name)
256-
else:
257-
fetch_var = global_block.create_var(
258-
name=fetch_var_name,
259-
type=core.VarDesc.VarType.FETCH_LIST,
260-
persistable=True)
261-
262-
if not has_feed_operators(global_block, feed, feed_var_name):
263-
for i, name in enumerate(feed):
264-
out = global_block.var(name)
265-
global_block.prepend_op(
266-
type='feed',
267-
inputs={'X': [feed_var]},
268-
outputs={'Out': [out]},
269-
attrs={'col': i})
270-
271-
for op in global_block.ops:
272-
if op.desc.type() == 'feed':
273-
feed_target_name = op.desc.output('Out')[0]
274-
cur_feed = feed[feed_target_name]
275-
if not isinstance(cur_feed, core.LoDTensor):
276-
cur_feed = self.aslodtensor(cur_feed)
277-
idx = op.desc.attr('col')
278-
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
251+
global_block = program_cache.global_block()
252+
253+
if feed_var_name in global_block.vars:
254+
feed_var = global_block.var(feed_var_name)
255+
else:
256+
feed_var = global_block.create_var(
257+
name=feed_var_name,
258+
type=core.VarDesc.VarType.FEED_MINIBATCH,
259+
persistable=True)
260+
261+
if fetch_var_name in global_block.vars:
262+
fetch_var = global_block.var(fetch_var_name)
279263
else:
280-
break
281-
282-
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
283-
for i, var in enumerate(fetch_list):
284-
assert isinstance(var, Variable) or isinstance(var, str), (
285-
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
286-
global_block.append_op(
287-
type='fetch',
288-
inputs={'X': [var]},
289-
outputs={'Out': [fetch_var]},
290-
attrs={'col': i})
291-
292-
self.executor.run(program.desc, scope, 0, True, True)
264+
fetch_var = global_block.create_var(
265+
name=fetch_var_name,
266+
type=core.VarDesc.VarType.FETCH_LIST,
267+
persistable=True)
268+
269+
if not has_feed_operators(global_block, feed, feed_var_name):
270+
for i, name in enumerate(feed):
271+
out = global_block.var(name)
272+
global_block.prepend_op(
273+
type='feed',
274+
inputs={'X': [feed_var]},
275+
outputs={'Out': [out]},
276+
attrs={'col': i})
277+
278+
for op in global_block.ops:
279+
if op.desc.type() == 'feed':
280+
feed_target_name = op.desc.output('Out')[0]
281+
cur_feed = feed[feed_target_name]
282+
if not isinstance(cur_feed, core.LoDTensor):
283+
cur_feed = self.aslodtensor(cur_feed)
284+
idx = op.desc.attr('col')
285+
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
286+
else:
287+
break
288+
289+
if not has_fetch_operators(global_block, fetch_list,
290+
fetch_var_name):
291+
for i, var in enumerate(fetch_list):
292+
assert isinstance(var, Variable) or isinstance(var, str), (
293+
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
294+
global_block.append_op(
295+
type='fetch',
296+
inputs={'X': [var]},
297+
outputs={'Out': [fetch_var]},
298+
attrs={'col': i})
299+
300+
self.executor.run(program_cache.desc, scope, 0, True, True)
293301
outs = [
294302
core.get_fetch_variable(scope, fetch_var_name, i)
295303
for i in xrange(len(fetch_list))

0 commit comments

Comments
 (0)