
Commit 445ca3d

Merge pull request #12607 from jacquesqiao/add-unit-test-for-async-transpile
Add unit test for async transpile
2 parents (7555cfe + 2ae32f0), commit 445ca3d

File tree: 2 files changed, +87 -14 lines


python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 78 additions & 6 deletions
@@ -51,25 +51,26 @@ def get_main_program(self):
         self.origin_prog = main.clone()
         return main
 
-    def get_trainer(self, config=None):
-        t = self._transpiler_instance(config)
+    def get_trainer(self, config=None, sync_mode=True):
+        t = self._transpiler_instance(config, sync_mode)
         return t.get_trainer_program()
 
-    def get_pserver(self, ep, config=None):
-        t = self._transpiler_instance(config)
+    def get_pserver(self, ep, config=None, sync_mode=True):
+        t = self._transpiler_instance(config, sync_mode)
         pserver = t.get_pserver_program(ep)
         startup = t.get_startup_program(ep, pserver)
         return pserver, startup
 
-    def _transpiler_instance(self, config=None):
+    def _transpiler_instance(self, config=None, sync_mode=True):
         if not self.transpiler:
             main = self.get_main_program()
             self.transpiler = fluid.DistributeTranspiler(config=config)
             self.transpiler.transpile(
                 self.trainer_id,
                 program=main,
                 pservers=self.pserver_eps,
-                trainers=self.trainers)
+                trainers=self.trainers,
+                sync_mode=sync_mode)
 
         return self.transpiler
 
@@ -464,5 +465,76 @@ def transpiler_test_impl(self):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestAsyncLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["adam", "scale", "scale"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
+            'recv', 'recv', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
+class TestAsyncDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'recv', 'recv'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
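
To run just the two new async cases in isolation, something like the sketch below should work; the bare module import is an assumption that the working directory is python/paddle/fluid/tests/unittests.

# Sketch: load and run only the new async transpiler tests.
# Assumes CWD is python/paddle/fluid/tests/unittests so the import resolves.
import unittest

from test_dist_transpiler import (TestAsyncDistLookupTable,
                                  TestAsyncLocalLookupTable)

loader = unittest.defaultTestLoader
suite = unittest.TestSuite([
    loader.loadTestsFromTestCase(TestAsyncLocalLookupTable),
    loader.loadTestsFromTestCase(TestAsyncDistLookupTable),
])
unittest.TextTestRunner(verbosity=2).run(suite)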

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 9 additions & 8 deletions
@@ -293,14 +293,15 @@ def transpile(self,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
 
-        program.global_block().append_op(
-            type="fetch_barrier",
-            inputs={},
-            outputs={},
-            attrs={
-                "endpoints": pserver_endpoints,
-                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
-            })
+        if self.sync_mode:
+            program.global_block().append_op(
+                type="fetch_barrier",
+                inputs={},
+                outputs={},
+                attrs={
+                    "endpoints": pserver_endpoints,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
 
         for varname, splited_var in self.param_var_mapping.iteritems():
             if len(splited_var) <= 1: