Skip to content

Commit a8fd6d5

Browse files
committed
add use_program_cache to executor.run
1 parent 0876fc1 commit a8fd6d5

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

python/paddle/fluid/executor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,19 @@ def run(self,
226226
feed_var_name='feed',
227227
fetch_var_name='fetch',
228228
scope=None,
229-
return_numpy=True):
229+
return_numpy=True,
230+
use_program_cache=False):
231+
"""
232+
:param program: the program that need to run
233+
:param feed: feed variable list
234+
:param fetch_list: fetch variable list
235+
:param feed_var_name: feed_var_name default to 'feed'
236+
:param fetch_var_name: fetch_var_name default to 'fetch'
237+
:param scope: the scope used to run this program, you can switch it to different scope.
238+
:param return_numpy: convert the fetched tensor to numpy
239+
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
240+
:return:
241+
"""
230242
if feed is None:
231243
feed = {}
232244
if fetch_list is None:
@@ -244,7 +256,7 @@ def run(self,
244256
program_cache_key = str(feed.keys() + fetch_list)
245257
program_cache = self.program_caches.get(program_cache_key, None)
246258

247-
if program_cache is None:
259+
if program_cache is None or not use_program_cache:
248260
program_cache = program.clone()
249261
self.program_caches[program_cache_key] = program_cache
250262

@@ -266,6 +278,7 @@ def run(self,
266278
type=core.VarDesc.VarType.FETCH_LIST,
267279
persistable=True)
268280

281+
# prepend feed operators
269282
if not has_feed_operators(global_block, feed, feed_var_name):
270283
for i, name in enumerate(feed):
271284
out = global_block.var(name)
@@ -275,6 +288,7 @@ def run(self,
275288
outputs={'Out': [out]},
276289
attrs={'col': i})
277290

291+
# append fetch_operators
278292
if not has_fetch_operators(global_block, fetch_list,
279293
fetch_var_name):
280294
for i, var in enumerate(fetch_list):

0 commit comments

Comments
 (0)