@@ -51,25 +51,26 @@ def get_main_program(self):
5151 self .origin_prog = main .clone ()
5252 return main
5353
54- def get_trainer (self , config = None ):
55- t = self ._transpiler_instance (config )
54+ def get_trainer (self , config = None , sync_mode = True ):
55+ t = self ._transpiler_instance (config , sync_mode )
5656 return t .get_trainer_program ()
5757
58- def get_pserver (self , ep , config = None ):
59- t = self ._transpiler_instance (config )
58+ def get_pserver (self , ep , config = None , sync_mode = True ):
59+ t = self ._transpiler_instance (config , sync_mode )
6060 pserver = t .get_pserver_program (ep )
6161 startup = t .get_startup_program (ep , pserver )
6262 return pserver , startup
6363
64- def _transpiler_instance (self , config = None ):
64+ def _transpiler_instance (self , config = None , sync_mode = True ):
6565 if not self .transpiler :
6666 main = self .get_main_program ()
6767 self .transpiler = fluid .DistributeTranspiler (config = config )
6868 self .transpiler .transpile (
6969 self .trainer_id ,
7070 program = main ,
7171 pservers = self .pserver_eps ,
72- trainers = self .trainers )
72+ trainers = self .trainers ,
73+ sync_mode = sync_mode )
7374
7475 return self .transpiler
7576
@@ -470,8 +471,7 @@ def net_conf(self):
470471
471472 def transpiler_test_impl (self ):
472473 config = fluid .DistributeTranspilerConfig ()
473- config .sync_mode = False
474- pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config )
474+ pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config , False )
475475
476476 self .assertEqual (len (pserver1 .blocks ), 3 )
477477 # 0 listen_and_serv
@@ -503,9 +503,8 @@ def net_conf(self):
503503
504504 def transpiler_test_impl (self ):
505505 config = fluid .DistributeTranspilerConfig ()
506- config .sync_mode = False
507506
508- pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config )
507+ pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config , False )
509508
510509 self .assertEqual (len (pserver1 .blocks ), 6 )
511510 # 0 listen_and_serv
@@ -525,7 +524,6 @@ def transpiler_test_impl(self):
525524
526525 trainer = self .get_trainer (config )
527526 self .assertEqual (len (trainer .blocks ), 1 )
528- print ([op .type for op in trainer .blocks [0 ].ops ])
529527 ops = [
530528 'split_ids' , 'prefetch' , 'merge_ids' , 'sequence_pool' , 'split_ids' ,
531529 'prefetch' , 'merge_ids' , 'sequence_pool' , 'concat' , 'mul' ,
0 commit comments