@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
163
163
return tensor
164
164
165
165
166
+ def get_program_cache_key (feed , fetch_list ):
167
+ feed_var_names = feed .keys ()
168
+
169
+ def to_name_str (var ):
170
+ if isinstance (var , Variable ):
171
+ return var .desc .name ()
172
+ elif isinstance (var , str ):
173
+ return var
174
+ else :
175
+ raise TypeError (str (var ) + " should be Variable or str" )
176
+
177
+ fetch_var_names = map (to_name_str , fetch_list )
178
+
179
+ return str (feed_var_names + fetch_var_names )
180
+
181
+
166
182
class Executor (object ):
167
183
def __init__ (self , places ):
168
184
if not isinstance (places , list ) and not isinstance (places , tuple ):
@@ -177,6 +193,7 @@ def __init__(self, places):
177
193
# TODO(dzhwinter) : only use the first place
178
194
self .executor = core .Executor (act_places [0 ])
179
195
self .places = places
196
+ self .program_caches = dict ()
180
197
181
198
def aslodtensor (self , data ):
182
199
def accumulate (data ):
@@ -225,9 +242,30 @@ def run(self,
225
242
feed_var_name = 'feed' ,
226
243
fetch_var_name = 'fetch' ,
227
244
scope = None ,
228
- return_numpy = True ):
245
+ return_numpy = True ,
246
+ use_program_cache = False ):
247
+ """ Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
248
+
249
+ Python executor takes a program, add feed operators and fetch operators to this program according
250
+ to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
251
+ the variables(or names) that user want to get after program run. Note: the executor will run all
252
+ operators in the program but not only the operators dependent by the fetch_list
253
+
254
+ :param program: the program that need to run, if not provied, then default_main_program will be used.
255
+ :param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
256
+ :param fetch_list: a list of variable or variable names that user want to get, run will return them according
257
+ to this list.
258
+ :param feed_var_name: the name for the input variable of feed Operator.
259
+ :param fetch_var_name: the name for the output variable of feed Operator.
260
+ :param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
261
+ :param return_numpy: if convert the fetched tensor to numpy
262
+ :param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
263
+ :return: result according to fetch_list.
264
+ """
229
265
if feed is None :
230
266
feed = {}
267
+ if not isinstance (feed , dict ):
268
+ raise TypeError ("feed should be a map" )
231
269
if fetch_list is None :
232
270
fetch_list = []
233
271
@@ -240,35 +278,64 @@ def run(self,
240
278
if scope is None :
241
279
scope = global_scope ()
242
280
243
- program = program . clone ()
244
- global_block = program . global_block ( )
281
+ program_cache = None
282
+ program_cache_key = get_program_cache_key ( feed , fetch_list )
245
283
246
- if feed_var_name in global_block .vars :
247
- feed_var = global_block .var (feed_var_name )
284
+ if use_program_cache :
285
+ # find program cache by cache_key
286
+ program_cache = self .program_caches .get (program_cache_key , None )
287
+ # TODO(qiao): Should check program_cache and program are exactly the same.
248
288
else :
249
- feed_var = global_block .create_var (
250
- name = feed_var_name ,
251
- type = core .VarDesc .VarType .FEED_MINIBATCH ,
252
- persistable = True )
289
+ self .program_caches .pop (program_cache_key , None )
253
290
254
- if fetch_var_name in global_block .vars :
255
- fetch_var = global_block .var (fetch_var_name )
256
- else :
257
- fetch_var = global_block .create_var (
258
- name = fetch_var_name ,
259
- type = core .VarDesc .VarType .FETCH_LIST ,
260
- persistable = True )
261
-
262
- if not has_feed_operators (global_block , feed , feed_var_name ):
263
- for i , name in enumerate (feed ):
264
- out = global_block .var (name )
265
- global_block .prepend_op (
266
- type = 'feed' ,
267
- inputs = {'X' : [feed_var ]},
268
- outputs = {'Out' : [out ]},
269
- attrs = {'col' : i })
270
-
271
- for op in global_block .ops :
291
+ if program_cache is None :
292
+ program_cache = program .clone ()
293
+
294
+ if use_program_cache :
295
+ self .program_caches [program_cache_key ] = program_cache
296
+
297
+ global_block = program_cache .global_block ()
298
+
299
+ if feed_var_name in global_block .vars :
300
+ feed_var = global_block .var (feed_var_name )
301
+ else :
302
+ feed_var = global_block .create_var (
303
+ name = feed_var_name ,
304
+ type = core .VarDesc .VarType .FEED_MINIBATCH ,
305
+ persistable = True )
306
+
307
+ if fetch_var_name in global_block .vars :
308
+ fetch_var = global_block .var (fetch_var_name )
309
+ else :
310
+ fetch_var = global_block .create_var (
311
+ name = fetch_var_name ,
312
+ type = core .VarDesc .VarType .FETCH_LIST ,
313
+ persistable = True )
314
+
315
+ # prepend feed operators
316
+ if not has_feed_operators (global_block , feed , feed_var_name ):
317
+ for i , name in enumerate (feed ):
318
+ out = global_block .var (name )
319
+ global_block .prepend_op (
320
+ type = 'feed' ,
321
+ inputs = {'X' : [feed_var ]},
322
+ outputs = {'Out' : [out ]},
323
+ attrs = {'col' : i })
324
+
325
+ # append fetch_operators
326
+ if not has_fetch_operators (global_block , fetch_list ,
327
+ fetch_var_name ):
328
+ for i , var in enumerate (fetch_list ):
329
+ assert isinstance (var , Variable ) or isinstance (var , str ), (
330
+ "Wrong type for fetch_list[%s]: %s" % (i , type (var )))
331
+ global_block .append_op (
332
+ type = 'fetch' ,
333
+ inputs = {'X' : [var ]},
334
+ outputs = {'Out' : [fetch_var ]},
335
+ attrs = {'col' : i })
336
+
337
+ # feed var to framework
338
+ for op in program_cache .global_block ().ops :
272
339
if op .desc .type () == 'feed' :
273
340
feed_target_name = op .desc .output ('Out' )[0 ]
274
341
cur_feed = feed [feed_target_name ]
@@ -279,17 +346,7 @@ def run(self,
279
346
else :
280
347
break
281
348
282
- if not has_fetch_operators (global_block , fetch_list , fetch_var_name ):
283
- for i , var in enumerate (fetch_list ):
284
- assert isinstance (var , Variable ) or isinstance (var , str ), (
285
- "Wrong type for fetch_list[%s]: %s" % (i , type (var )))
286
- global_block .append_op (
287
- type = 'fetch' ,
288
- inputs = {'X' : [var ]},
289
- outputs = {'Out' : [fetch_var ]},
290
- attrs = {'col' : i })
291
-
292
- self .executor .run (program .desc , scope , 0 , True , True )
349
+ self .executor .run (program_cache .desc , scope , 0 , True , True )
293
350
outs = [
294
351
core .get_fetch_variable (scope , fetch_var_name , i )
295
352
for i in xrange (len (fetch_list ))
0 commit comments