@@ -226,7 +226,19 @@ def run(self,
226
226
feed_var_name = 'feed' ,
227
227
fetch_var_name = 'fetch' ,
228
228
scope = None ,
229
- return_numpy = True ):
229
+ return_numpy = True ,
230
+ use_program_cache = False ):
231
+ """
232
+ :param program: the program that need to run
233
+ :param feed: feed variable list
234
+ :param fetch_list: fetch variable list
235
+ :param feed_var_name: feed_var_name default to 'feed'
236
+ :param fetch_var_name: fetch_var_name default to 'fetch'
237
+ :param scope: the scope used to run this program, you can switch it to different scope.
238
+ :param return_numpy: convert the fetched tensor to numpy
239
+ :param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
240
+ :return:
241
+ """
230
242
if feed is None :
231
243
feed = {}
232
244
if fetch_list is None :
@@ -244,7 +256,7 @@ def run(self,
244
256
program_cache_key = str (feed .keys () + fetch_list )
245
257
program_cache = self .program_caches .get (program_cache_key , None )
246
258
247
- if program_cache is None :
259
+ if program_cache is None or not use_program_cache :
248
260
program_cache = program .clone ()
249
261
self .program_caches [program_cache_key ] = program_cache
250
262
@@ -266,6 +278,7 @@ def run(self,
266
278
type = core .VarDesc .VarType .FETCH_LIST ,
267
279
persistable = True )
268
280
281
+ # prepend feed operators
269
282
if not has_feed_operators (global_block , feed , feed_var_name ):
270
283
for i , name in enumerate (feed ):
271
284
out = global_block .var (name )
@@ -275,6 +288,7 @@ def run(self,
275
288
outputs = {'Out' : [out ]},
276
289
attrs = {'col' : i })
277
290
291
+ # append fetch_operators
278
292
if not has_fetch_operators (global_block , fetch_list ,
279
293
fetch_var_name ):
280
294
for i , var in enumerate (fetch_list ):
0 commit comments