@@ -177,6 +177,7 @@ def __init__(self, places):
177
177
# TODO(dzhwinter) : only use the first place
178
178
self .executor = core .Executor (act_places [0 ])
179
179
self .places = places
180
+ self .program_caches = dict ()
180
181
181
182
def aslodtensor (self , data ):
182
183
def accumulate (data ):
@@ -240,56 +241,63 @@ def run(self,
240
241
if scope is None :
241
242
scope = global_scope ()
242
243
243
- program = program . clone ( )
244
- global_block = program . global_block ( )
244
+ program_cache_key = str ( feed . keys () + fetch_list )
245
+ program_cache = self . program_caches . get ( program_cache_key , None )
245
246
246
- if feed_var_name in global_block .vars :
247
- feed_var = global_block .var (feed_var_name )
248
- else :
249
- feed_var = global_block .create_var (
250
- name = feed_var_name ,
251
- type = core .VarDesc .VarType .FEED_MINIBATCH ,
252
- persistable = True )
247
+ if program_cache is None :
248
+ program_cache = program .clone ()
249
+ self .program_caches [program_cache_key ] = program_cache
253
250
254
- if fetch_var_name in global_block .vars :
255
- fetch_var = global_block .var (fetch_var_name )
256
- else :
257
- fetch_var = global_block .create_var (
258
- name = fetch_var_name ,
259
- type = core .VarDesc .VarType .FETCH_LIST ,
260
- persistable = True )
261
-
262
- if not has_feed_operators (global_block , feed , feed_var_name ):
263
- for i , name in enumerate (feed ):
264
- out = global_block .var (name )
265
- global_block .prepend_op (
266
- type = 'feed' ,
267
- inputs = {'X' : [feed_var ]},
268
- outputs = {'Out' : [out ]},
269
- attrs = {'col' : i })
270
-
271
- for op in global_block .ops :
272
- if op .desc .type () == 'feed' :
273
- feed_target_name = op .desc .output ('Out' )[0 ]
274
- cur_feed = feed [feed_target_name ]
275
- if not isinstance (cur_feed , core .LoDTensor ):
276
- cur_feed = self .aslodtensor (cur_feed )
277
- idx = op .desc .attr ('col' )
278
- core .set_feed_variable (scope , cur_feed , feed_var_name , idx )
251
+ global_block = program_cache .global_block ()
252
+
253
+ if feed_var_name in global_block .vars :
254
+ feed_var = global_block .var (feed_var_name )
255
+ else :
256
+ feed_var = global_block .create_var (
257
+ name = feed_var_name ,
258
+ type = core .VarDesc .VarType .FEED_MINIBATCH ,
259
+ persistable = True )
260
+
261
+ if fetch_var_name in global_block .vars :
262
+ fetch_var = global_block .var (fetch_var_name )
279
263
else :
280
- break
281
-
282
- if not has_fetch_operators (global_block , fetch_list , fetch_var_name ):
283
- for i , var in enumerate (fetch_list ):
284
- assert isinstance (var , Variable ) or isinstance (var , str ), (
285
- "Wrong type for fetch_list[%s]: %s" % (i , type (var )))
286
- global_block .append_op (
287
- type = 'fetch' ,
288
- inputs = {'X' : [var ]},
289
- outputs = {'Out' : [fetch_var ]},
290
- attrs = {'col' : i })
291
-
292
- self .executor .run (program .desc , scope , 0 , True , True )
264
+ fetch_var = global_block .create_var (
265
+ name = fetch_var_name ,
266
+ type = core .VarDesc .VarType .FETCH_LIST ,
267
+ persistable = True )
268
+
269
+ if not has_feed_operators (global_block , feed , feed_var_name ):
270
+ for i , name in enumerate (feed ):
271
+ out = global_block .var (name )
272
+ global_block .prepend_op (
273
+ type = 'feed' ,
274
+ inputs = {'X' : [feed_var ]},
275
+ outputs = {'Out' : [out ]},
276
+ attrs = {'col' : i })
277
+
278
+ for op in global_block .ops :
279
+ if op .desc .type () == 'feed' :
280
+ feed_target_name = op .desc .output ('Out' )[0 ]
281
+ cur_feed = feed [feed_target_name ]
282
+ if not isinstance (cur_feed , core .LoDTensor ):
283
+ cur_feed = self .aslodtensor (cur_feed )
284
+ idx = op .desc .attr ('col' )
285
+ core .set_feed_variable (scope , cur_feed , feed_var_name , idx )
286
+ else :
287
+ break
288
+
289
+ if not has_fetch_operators (global_block , fetch_list ,
290
+ fetch_var_name ):
291
+ for i , var in enumerate (fetch_list ):
292
+ assert isinstance (var , Variable ) or isinstance (var , str ), (
293
+ "Wrong type for fetch_list[%s]: %s" % (i , type (var )))
294
+ global_block .append_op (
295
+ type = 'fetch' ,
296
+ inputs = {'X' : [var ]},
297
+ outputs = {'Out' : [fetch_var ]},
298
+ attrs = {'col' : i })
299
+
300
+ self .executor .run (program_cache .desc , scope , 0 , True , True )
293
301
outs = [
294
302
core .get_fetch_variable (scope , fetch_var_name , i )
295
303
for i in xrange (len (fetch_list ))
0 commit comments