@@ -235,6 +235,77 @@ def parselod(data):
235
235
tensor .set_lod (lod )
236
236
return tensor
237
237
238
+ def _get_program_cache (self , program_cache_key ):
239
+ return self .program_caches .get (program_cache_key , None )
240
+
241
+ def _add_program_cache (self , program_cache_key , program ):
242
+ self .program_caches [program_cache_key ] = program
243
+
244
+ def _add_feed_fetch_ops (self , program , feed , fetch_list , feed_var_name ,
245
+ fetch_var_name ):
246
+ tmp_program = program .clone ()
247
+
248
+ global_block = tmp_program .global_block ()
249
+
250
+ if feed_var_name in global_block .vars :
251
+ feed_var = global_block .var (feed_var_name )
252
+ else :
253
+ feed_var = global_block .create_var (
254
+ name = feed_var_name ,
255
+ type = core .VarDesc .VarType .FEED_MINIBATCH ,
256
+ persistable = True )
257
+
258
+ if fetch_var_name in global_block .vars :
259
+ fetch_var = global_block .var (fetch_var_name )
260
+ else :
261
+ fetch_var = global_block .create_var (
262
+ name = fetch_var_name ,
263
+ type = core .VarDesc .VarType .FETCH_LIST ,
264
+ persistable = True )
265
+
266
+ # prepend feed operators
267
+ if not has_feed_operators (global_block , feed , feed_var_name ):
268
+ for i , name in enumerate (feed ):
269
+ out = global_block .var (name )
270
+ global_block .prepend_op (
271
+ type = 'feed' ,
272
+ inputs = {'X' : [feed_var ]},
273
+ outputs = {'Out' : [out ]},
274
+ attrs = {'col' : i })
275
+
276
+ # append fetch_operators
277
+ if not has_fetch_operators (global_block , fetch_list , fetch_var_name ):
278
+ for i , var in enumerate (fetch_list ):
279
+ assert isinstance (var , Variable ) or isinstance (var , str ), (
280
+ "Wrong type for fetch_list[%s]: %s" % (i , type (var )))
281
+ global_block .append_op (
282
+ type = 'fetch' ,
283
+ inputs = {'X' : [var ]},
284
+ outputs = {'Out' : [fetch_var ]},
285
+ attrs = {'col' : i })
286
+
287
+ return tmp_program
288
+
289
+ def _feed_data (self , program , feed , feed_var_name , scope ):
290
+ # feed var to framework
291
+ for op in program .global_block ().ops :
292
+ if op .desc .type () == 'feed' :
293
+ feed_target_name = op .desc .output ('Out' )[0 ]
294
+ cur_feed = feed [feed_target_name ]
295
+ if not isinstance (cur_feed , core .LoDTensor ):
296
+ cur_feed = self .aslodtensor (cur_feed )
297
+ idx = op .desc .attr ('col' )
298
+ core .set_feed_variable (scope , cur_feed , feed_var_name , idx )
299
+ else :
300
+ break
301
+
302
+ def _fetch_data (self , fetch_list , fetch_var_name , scope ):
303
+ outs = [
304
+ core .get_fetch_variable (scope , fetch_var_name , i )
305
+ for i in xrange (len (fetch_list ))
306
+ ]
307
+ return outs
308
+
238
309
def run (self ,
239
310
program = None ,
240
311
feed = None ,
@@ -268,7 +339,6 @@ def run(self,
268
339
raise TypeError ("feed should be a map" )
269
340
if fetch_list is None :
270
341
fetch_list = []
271
-
272
342
if program is None :
273
343
program = default_main_program ()
274
344
@@ -278,79 +348,30 @@ def run(self,
278
348
if scope is None :
279
349
scope = global_scope ()
280
350
281
- program_cache = None
282
- program_cache_key = get_program_cache_key (feed , fetch_list )
283
-
351
+ cache_key = get_program_cache_key (feed , fetch_list )
284
352
if use_program_cache :
285
- # find program cache by cache_key
286
- program_cache = self .program_caches .get (program_cache_key , None )
287
- # TODO(qiao): Should check program_cache and program are exactly the same.
353
+ cached_program = self ._get_program_cache (cache_key )
354
+ if cached_program is None :
355
+ cached_program = self ._add_feed_fetch_ops (
356
+ program = program ,
357
+ feed = feed ,
358
+ fetch_list = fetch_list ,
359
+ feed_var_name = feed_var_name ,
360
+ fetch_var_name = fetch_var_name )
361
+ self ._add_program_cache (cache_key , cached_program )
362
+ program = cached_program
288
363
else :
289
- self .program_caches .pop (program_cache_key , None )
290
-
291
- if program_cache is None :
292
- program_cache = program .clone ()
293
-
294
- if use_program_cache :
295
- self .program_caches [program_cache_key ] = program_cache
296
-
297
- global_block = program_cache .global_block ()
298
-
299
- if feed_var_name in global_block .vars :
300
- feed_var = global_block .var (feed_var_name )
301
- else :
302
- feed_var = global_block .create_var (
303
- name = feed_var_name ,
304
- type = core .VarDesc .VarType .FEED_MINIBATCH ,
305
- persistable = True )
306
-
307
- if fetch_var_name in global_block .vars :
308
- fetch_var = global_block .var (fetch_var_name )
309
- else :
310
- fetch_var = global_block .create_var (
311
- name = fetch_var_name ,
312
- type = core .VarDesc .VarType .FETCH_LIST ,
313
- persistable = True )
314
-
315
- # prepend feed operators
316
- if not has_feed_operators (global_block , feed , feed_var_name ):
317
- for i , name in enumerate (feed ):
318
- out = global_block .var (name )
319
- global_block .prepend_op (
320
- type = 'feed' ,
321
- inputs = {'X' : [feed_var ]},
322
- outputs = {'Out' : [out ]},
323
- attrs = {'col' : i })
324
-
325
- # append fetch_operators
326
- if not has_fetch_operators (global_block , fetch_list ,
327
- fetch_var_name ):
328
- for i , var in enumerate (fetch_list ):
329
- assert isinstance (var , Variable ) or isinstance (var , str ), (
330
- "Wrong type for fetch_list[%s]: %s" % (i , type (var )))
331
- global_block .append_op (
332
- type = 'fetch' ,
333
- inputs = {'X' : [var ]},
334
- outputs = {'Out' : [fetch_var ]},
335
- attrs = {'col' : i })
336
-
337
- # feed var to framework
338
- for op in program_cache .global_block ().ops :
339
- if op .desc .type () == 'feed' :
340
- feed_target_name = op .desc .output ('Out' )[0 ]
341
- cur_feed = feed [feed_target_name ]
342
- if not isinstance (cur_feed , core .LoDTensor ):
343
- cur_feed = self .aslodtensor (cur_feed )
344
- idx = op .desc .attr ('col' )
345
- core .set_feed_variable (scope , cur_feed , feed_var_name , idx )
346
- else :
347
- break
348
-
349
- self .executor .run (program_cache .desc , scope , 0 , True , True )
350
- outs = [
351
- core .get_fetch_variable (scope , fetch_var_name , i )
352
- for i in xrange (len (fetch_list ))
353
- ]
364
+ self .program_caches .pop (cache_key , None )
365
+ program = self ._add_feed_fetch_ops (
366
+ program = program ,
367
+ feed = feed ,
368
+ fetch_list = fetch_list ,
369
+ feed_var_name = feed_var_name ,
370
+ fetch_var_name = fetch_var_name )
371
+
372
+ self ._feed_data (program , feed , feed_var_name , scope )
373
+ self .executor .run (program .desc , scope , 0 , True , True )
374
+ outs = self ._fetch_data (fetch_list , fetch_var_name , scope )
354
375
if return_numpy :
355
376
outs = as_numpy (outs )
356
377
return outs
0 commit comments