
Commit 767acc6

Merge pull request #8744 from jacquesqiao/add-program-cache-for-executor
Add program cache for executor.py
2 parents 86263b2 + 5d9dbe1 commit 767acc6

File tree: 2 files changed (+96, −39 lines)

python/paddle/fluid/executor.py

Lines changed: 95 additions & 38 deletions
@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
     return tensor


+def get_program_cache_key(feed, fetch_list):
+    feed_var_names = feed.keys()
+
+    def to_name_str(var):
+        if isinstance(var, Variable):
+            return var.desc.name()
+        elif isinstance(var, str):
+            return var
+        else:
+            raise TypeError(str(var) + " should be Variable or str")
+
+    fetch_var_names = map(to_name_str, fetch_list)
+
+    return str(feed_var_names + fetch_var_names)
+
+
 class Executor(object):
     def __init__(self, places):
         if not isinstance(places, list) and not isinstance(places, tuple):
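
The key is simply the stringified list of feed names followed by fetch names (feed.keys() and map() both return lists under Python 2, which this file targets, so the + concatenation works). A minimal standalone sketch of the behavior, written in Python 3 form purely for illustration:

def cache_key_sketch(feed, fetch_list):
    # Only variable *names* enter the key; the feed data itself does not,
    # so every step with the same feed/fetch structure hits the same entry.
    feed_var_names = list(feed.keys())
    fetch_var_names = [v if isinstance(v, str) else v.desc.name()
                       for v in fetch_list]
    return str(feed_var_names + fetch_var_names)

# Different data, same structure -> same key, so a cached program is reusable.
assert (cache_key_sketch({"image": 1, "label": 2}, ["loss"]) ==
        cache_key_sketch({"image": 3, "label": 4}, ["loss"]))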
@@ -177,6 +193,7 @@ def __init__(self, places):
         # TODO(dzhwinter) : only use the first place
         self.executor = core.Executor(act_places[0])
         self.places = places
+        self.program_caches = dict()

     def aslodtensor(self, data):
         def accumulate(data):
@@ -225,9 +242,30 @@ def run(self,
             feed_var_name='feed',
             fetch_var_name='fetch',
             scope=None,
-            return_numpy=True):
+            return_numpy=True,
+            use_program_cache=False):
+        """Run a program by this Executor. Feed data via the feed map, fetch results via fetch_list.
+
+        The Python executor takes a program and adds feed and fetch operators to it according
+        to the feed map and fetch_list. The feed map provides input data for the program; the
+        fetch_list names the variables (or Variable objects) the user wants after the program runs.
+        Note: the executor runs all operators in the program, not only those the fetch_list depends on.
+
+        :param program: the program to run; if not provided, default_main_program is used.
+        :param feed: a feed variable map, e.g. {"image": ImageData, "label": LabelData}
+        :param fetch_list: a list of Variables or variable names to fetch; run will return
+            them in the order of this list.
+        :param feed_var_name: the name of the input variable of the feed operator.
+        :param fetch_var_name: the name of the output variable of the fetch operator.
+        :param scope: the scope used to run this program; defaults to global_scope.
+        :param return_numpy: whether to convert the fetched tensors to numpy arrays.
+        :param use_program_cache: set to True if the program has not changed since the last step.
+        :return: the results according to fetch_list.
+        """
         if feed is None:
             feed = {}
+        if not isinstance(feed, dict):
+            raise TypeError("feed should be a map")
         if fetch_list is None:
             fetch_list = []

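With the new use_program_cache flag, a training loop that feeds and fetches the same variables every step pays the clone-and-instrument cost only once. A hedged usage sketch (the toy network and shapes below are invented for illustration; only the run(...) call pattern comes from this diff):

import numpy
import paddle.fluid as fluid

# Hypothetical one-layer network, just to give run() something to execute.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=image, size=10))

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

for step in range(10):
    data = numpy.random.random((8, 784)).astype('float32')
    # Same feed names and fetch_list each step -> same cache key, so the
    # feed/fetch-operator insertion happens only on the first step.
    loss_val, = exe.run(fluid.default_main_program(),
                        feed={'image': data},
                        fetch_list=[loss],
                        use_program_cache=True)
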
@@ -240,35 +278,64 @@ def run(self,
         if scope is None:
             scope = global_scope()

-        program = program.clone()
-        global_block = program.global_block()
+        program_cache = None
+        program_cache_key = get_program_cache_key(feed, fetch_list)

-        if feed_var_name in global_block.vars:
-            feed_var = global_block.var(feed_var_name)
+        if use_program_cache:
+            # find program cache by cache_key
+            program_cache = self.program_caches.get(program_cache_key, None)
+            # TODO(qiao): Should check program_cache and program are exactly the same.
         else:
-            feed_var = global_block.create_var(
-                name=feed_var_name,
-                type=core.VarDesc.VarType.FEED_MINIBATCH,
-                persistable=True)
+            self.program_caches.pop(program_cache_key, None)

-        if fetch_var_name in global_block.vars:
-            fetch_var = global_block.var(fetch_var_name)
-        else:
-            fetch_var = global_block.create_var(
-                name=fetch_var_name,
-                type=core.VarDesc.VarType.FETCH_LIST,
-                persistable=True)
-
-        if not has_feed_operators(global_block, feed, feed_var_name):
-            for i, name in enumerate(feed):
-                out = global_block.var(name)
-                global_block.prepend_op(
-                    type='feed',
-                    inputs={'X': [feed_var]},
-                    outputs={'Out': [out]},
-                    attrs={'col': i})
-
-        for op in global_block.ops:
+        if program_cache is None:
+            program_cache = program.clone()
+
+            if use_program_cache:
+                self.program_caches[program_cache_key] = program_cache
+
+            global_block = program_cache.global_block()
+
+            if feed_var_name in global_block.vars:
+                feed_var = global_block.var(feed_var_name)
+            else:
+                feed_var = global_block.create_var(
+                    name=feed_var_name,
+                    type=core.VarDesc.VarType.FEED_MINIBATCH,
+                    persistable=True)
+
+            if fetch_var_name in global_block.vars:
+                fetch_var = global_block.var(fetch_var_name)
+            else:
+                fetch_var = global_block.create_var(
+                    name=fetch_var_name,
+                    type=core.VarDesc.VarType.FETCH_LIST,
+                    persistable=True)
+
+            # prepend feed operators
+            if not has_feed_operators(global_block, feed, feed_var_name):
+                for i, name in enumerate(feed):
+                    out = global_block.var(name)
+                    global_block.prepend_op(
+                        type='feed',
+                        inputs={'X': [feed_var]},
+                        outputs={'Out': [out]},
+                        attrs={'col': i})
+
+            # append fetch_operators
+            if not has_fetch_operators(global_block, fetch_list,
+                                       fetch_var_name):
+                for i, var in enumerate(fetch_list):
+                    assert isinstance(var, Variable) or isinstance(var, str), (
+                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
+                    global_block.append_op(
+                        type='fetch',
+                        inputs={'X': [var]},
+                        outputs={'Out': [fetch_var]},
+                        attrs={'col': i})
+
+        # feed var to framework
+        for op in program_cache.global_block().ops:
             if op.desc.type() == 'feed':
                 feed_target_name = op.desc.output('Out')[0]
                 cur_feed = feed[feed_target_name]
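
Stripped of the fluid specifics, the control flow above is a get-or-build cache keyed by the feed/fetch signature: a hit reuses the instrumented clone, while a miss (or use_program_cache=False, which also evicts any stale entry) rebuilds it. A framework-free sketch of the same shape, with stand-in names:

class Program(object):
    # Stand-in for fluid's Program; clone() is the costly step being cached.
    def clone(self):
        return Program()

class ExecutorSketch(object):
    def __init__(self):
        self.program_caches = dict()

    def prepare(self, program, cache_key, use_program_cache=False):
        program_cache = None
        if use_program_cache:
            program_cache = self.program_caches.get(cache_key, None)
        else:
            self.program_caches.pop(cache_key, None)  # drop stale entry

        if program_cache is None:
            program_cache = program.clone()
            if use_program_cache:
                self.program_caches[cache_key] = program_cache
            # ...feed/fetch operator insertion would happen only here...
        return program_cache

exe, prog = ExecutorSketch(), Program()
a = exe.prepare(prog, "['image']['loss']", use_program_cache=True)
b = exe.prepare(prog, "['image']['loss']", use_program_cache=True)
assert a is b  # the second step reuses the first step's clone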
@@ -279,17 +346,7 @@ def run(self,
             else:
                 break

-        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
-            for i, var in enumerate(fetch_list):
-                assert isinstance(var, Variable) or isinstance(var, str), (
-                    "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
-                global_block.append_op(
-                    type='fetch',
-                    inputs={'X': [var]},
-                    outputs={'Out': [fetch_var]},
-                    attrs={'col': i})
-
-        self.executor.run(program.desc, scope, 0, True, True)
+        self.executor.run(program_cache.desc, scope, 0, True, True)
         outs = [
             core.get_fetch_variable(scope, fetch_var_name, i)
             for i in xrange(len(fetch_list))

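Note that the fetch-operator insertion did not disappear in this last hunk; it moved up into the cache-miss branch, so both feed and fetch instrumentation are skipped entirely on a cache hit, and the executor now runs the cached program's desc.
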
python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
         exe.run(fluid.default_startup_program())
         for step in range(10):
             lr_val, = exe.run(fluid.default_main_program(),
-                              feed=[],
+                              feed={},
                               fetch_list=[decayed_lr])
             python_decayed_lr = python_decay_fn(
                 global_step=float(step), **kwargs)
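
This one-line fix follows from the new type check added to run() above: feed must now be a dict, so the old feed=[] would raise TypeError("feed should be a map").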
