Skip to content

Commit 2ae32f0

Browse files
committed
revert the change of api
1 parent 1b69021 commit 2ae32f0

File tree

2 files changed

+10
-13
lines changed

2 files changed

+10
-13
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,26 @@ def get_main_program(self):
5151
self.origin_prog = main.clone()
5252
return main
5353

54-
def get_trainer(self, config=None):
55-
t = self._transpiler_instance(config)
54+
def get_trainer(self, config=None, sync_mode=True):
55+
t = self._transpiler_instance(config, sync_mode)
5656
return t.get_trainer_program()
5757

58-
def get_pserver(self, ep, config=None):
59-
t = self._transpiler_instance(config)
58+
def get_pserver(self, ep, config=None, sync_mode=True):
59+
t = self._transpiler_instance(config, sync_mode)
6060
pserver = t.get_pserver_program(ep)
6161
startup = t.get_startup_program(ep, pserver)
6262
return pserver, startup
6363

64-
def _transpiler_instance(self, config=None):
64+
def _transpiler_instance(self, config=None, sync_mode=True):
6565
if not self.transpiler:
6666
main = self.get_main_program()
6767
self.transpiler = fluid.DistributeTranspiler(config=config)
6868
self.transpiler.transpile(
6969
self.trainer_id,
7070
program=main,
7171
pservers=self.pserver_eps,
72-
trainers=self.trainers)
72+
trainers=self.trainers,
73+
sync_mode=sync_mode)
7374

7475
return self.transpiler
7576

@@ -470,8 +471,7 @@ def net_conf(self):
470471

471472
def transpiler_test_impl(self):
472473
config = fluid.DistributeTranspilerConfig()
473-
config.sync_mode = False
474-
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
474+
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
475475

476476
self.assertEqual(len(pserver1.blocks), 3)
477477
# 0 listen_and_serv
@@ -503,9 +503,8 @@ def net_conf(self):
503503

504504
def transpiler_test_impl(self):
505505
config = fluid.DistributeTranspilerConfig()
506-
config.sync_mode = False
507506

508-
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
507+
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
509508

510509
self.assertEqual(len(pserver1.blocks), 6)
511510
# 0 listen_and_serv
@@ -525,7 +524,6 @@ def transpiler_test_impl(self):
525524

526525
trainer = self.get_trainer(config)
527526
self.assertEqual(len(trainer.blocks), 1)
528-
print([op.type for op in trainer.blocks[0].ops])
529527
ops = [
530528
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
531529
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ class DistributeTranspilerConfig(object):
124124
slice_var_up = True
125125
split_method = None
126126
min_block_size = 8192
127-
sync_mode = True
128127

129128

130129
class DistributeTranspiler(object):
@@ -198,7 +197,7 @@ def transpile(self,
198197
program = default_main_program()
199198
self.origin_program = program
200199
self.trainer_num = trainers
201-
self.sync_mode = sync_mode and self.config.sync_mode
200+
self.sync_mode = sync_mode
202201
self.trainer_id = trainer_id
203202
pserver_endpoints = pservers.split(",")
204203
self.pserver_endpoints = pserver_endpoints

0 commit comments

Comments
 (0)