@@ -305,7 +305,9 @@ class Executor(object):
305
305
def __init__ (self , place ):
306
306
self .place = place
307
307
self .program_caches = dict ()
308
- self .executor = None
308
+ p = core .Place ()
309
+ p .set_place (self .place )
310
+ self ._default_executor = core .Executor (p )
309
311
self ._closed = False
310
312
311
313
def _get_program_cache (self , program_cache_key ):
@@ -397,12 +399,13 @@ def close(self):
397
399
>>> ...
398
400
>>> exe.close()
399
401
"""
400
- if not self ._closed and self . executor :
401
- self .executor .close ()
402
+ if not self ._closed :
403
+ self ._default_executor .close ()
402
404
self ._closed = True
403
405
404
406
def _run_parallel (self , program , scope , feed , fetch_list , fetch_var_name ,
405
407
return_numpy ):
408
+ exe = program ._executor
406
409
if isinstance (feed , dict ):
407
410
feed_tensor_dict = dict ()
408
411
for feed_name in feed :
@@ -414,8 +417,7 @@ def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
414
417
feed_tensor .set (feed [feed_name ], core .CPUPlace ())
415
418
feed_tensor_dict [feed_name ] = feed_tensor
416
419
417
- self .executor .feed_and_split_tensor_into_local_scopes (
418
- feed_tensor_dict )
420
+ exe .feed_and_split_tensor_into_local_scopes (feed_tensor_dict )
419
421
elif isinstance (feed , list ) or isinstance (feed , tuple ):
420
422
if len (feed ) != len (program ._places ):
421
423
raise ValueError (
@@ -436,10 +438,10 @@ def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
436
438
tensor = tmp
437
439
res_dict [feed_name ] = tensor
438
440
res .append (res_dict )
439
- self . executor .feed_tensors_into_local_scopes (res )
441
+ exe .feed_tensors_into_local_scopes (res )
440
442
441
443
fetch_var_names = list (map (_to_name_str , fetch_list ))
442
- self . executor .run (fetch_var_names , fetch_var_name )
444
+ exe .run (fetch_var_names , fetch_var_name )
443
445
arr = scope .find_var (fetch_var_name ).get_lod_tensor_array ()
444
446
445
447
if return_numpy :
@@ -511,12 +513,9 @@ def run(self,
511
513
compiled = isinstance (program , compiler .CompiledProgram )
512
514
# For backward compatibility, run directly.
513
515
if not compiled :
514
- if not self .executor :
515
- p = core .Place ()
516
- p .set_place (self .place )
517
- self .executor = core .Executor (p )
518
516
return self ._run (
519
517
program ,
518
+ self ._default_executor ,
520
519
feed = feed ,
521
520
fetch_list = fetch_list ,
522
521
feed_var_name = feed_var_name ,
@@ -526,7 +525,6 @@ def run(self,
526
525
use_program_cache = use_program_cache )
527
526
528
527
program ._compile (scope , self .place )
529
- self .executor = program ._executor
530
528
if program ._is_data_parallel :
531
529
return self ._run_parallel (
532
530
program ,
@@ -536,12 +534,13 @@ def run(self,
536
534
fetch_var_name = fetch_var_name ,
537
535
return_numpy = return_numpy )
538
536
elif program ._is_inference :
539
- return self ._run_inference (program , feed )
537
+ return self ._run_inference (program . _executor , feed )
540
538
else :
541
539
# TODO(panyx0718): Can compile program to optimize executor
542
540
# performance.
543
541
return self ._run (
544
542
program ._program ,
543
+ self ._default_executor ,
545
544
feed = feed ,
546
545
fetch_list = fetch_list ,
547
546
feed_var_name = feed_var_name ,
@@ -550,8 +549,8 @@ def run(self,
550
549
return_numpy = return_numpy ,
551
550
use_program_cache = use_program_cache )
552
551
553
- def _run (self , program , feed , fetch_list , feed_var_name , fetch_var_name ,
554
- scope , return_numpy , use_program_cache ):
552
+ def _run (self , program , exe , feed , fetch_list , feed_var_name ,
553
+ fetch_var_name , scope , return_numpy , use_program_cache ):
555
554
556
555
if feed is None :
557
556
feed = {}
@@ -589,11 +588,11 @@ def _run(self, program, feed, fetch_list, feed_var_name, fetch_var_name,
589
588
fetch_var_name = fetch_var_name )
590
589
591
590
self ._feed_data (program , feed , feed_var_name , scope )
592
- self . executor .run (program .desc , scope , 0 , True , True )
591
+ exe .run (program .desc , scope , 0 , True , True )
593
592
outs = self ._fetch_data (fetch_list , fetch_var_name , scope )
594
593
if return_numpy :
595
594
outs = as_numpy (outs )
596
595
return outs
597
596
598
- def _run_inference (self , program , feed ):
599
- return self . executor .run (feed )
597
+ def _run_inference (self , exe , feed ):
598
+ return exe .run (feed )
0 commit comments