Skip to content

Commit 8da6510

Browse files
committed
add TestAsyncDistLookupTable
1 parent 66be532 commit 8da6510

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,5 +464,46 @@ def transpiler_test_impl(self):
464464
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
465465

466466

467+
class TestAsyncDistLookupTable(TestDistLookupTableBase):
    """Transpiler test for a distributed lookup table with sync_mode disabled.

    The network is built with a sparse, distributed lookup table and then
    transpiled using a ``DistributeTranspilerConfig`` whose ``sync_mode`` is
    set to ``False``. The test pins the block layout of the first parameter
    server program and the exact op sequence of the trainer program.
    """

    def net_conf(self):
        """Build the test network with a sparse, distributed lookup table."""
        self.network_with_table(is_sparse=True, is_distributed=True)

    def transpiler_test_impl(self):
        """Check the pserver blocks and trainer ops produced in async mode."""
        config = fluid.DistributeTranspilerConfig()
        config.sync_mode = False

        # Only the pserver main program is inspected; the startup program
        # returned alongside it is not needed here.
        pserver1, _ = self.get_pserver(self.pserver1_ep, config)

        self.assertEqual(len(pserver1.blocks), 6)
        # 0 listen_and_serv
        # 1 optimize for fc_w or fc_b adam
        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
                         ["adam", "scale", "scale"])
        # 2 optimize for table sgd
        self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
        # 3 prefetch -> lookup_sparse_table for data0
        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
                         ["lookup_sparse_table"])
        # 4 prefetch -> lookup_sparse_table for data1
        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
                         ["lookup_sparse_table"])
        # 5 save table
        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])

        trainer = self.get_trainer(config)
        self.assertEqual(len(trainer.blocks), 1)
        # Expected trainer op sequence; note that it contains no barrier ops.
        ops = [
            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
            'sum', 'split_ids', 'send', 'recv', 'recv'
        ]
        # On mismatch assertEqual already reports both lists, so no separate
        # debug print of the actual op types is required.
        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
506+
507+
467508
if __name__ == "__main__":
468509
unittest.main()

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class DistributeTranspilerConfig(object):
124124
slice_var_up = True
125125
split_method = None
126126
min_block_size = 8192
127+
sync_mode = True
127128

128129

129130
class DistributeTranspiler(object):
@@ -197,7 +198,7 @@ def transpile(self,
197198
program = default_main_program()
198199
self.origin_program = program
199200
self.trainer_num = trainers
200-
self.sync_mode = sync_mode
201+
self.sync_mode = sync_mode and self.config.sync_mode
201202
self.trainer_id = trainer_id
202203
pserver_endpoints = pservers.split(",")
203204
self.pserver_endpoints = pserver_endpoints
@@ -293,14 +294,15 @@ def transpile(self,
293294
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
294295
})
295296

296-
program.global_block().append_op(
297-
type="fetch_barrier",
298-
inputs={},
299-
outputs={},
300-
attrs={
301-
"endpoints": pserver_endpoints,
302-
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
303-
})
297+
if self.sync_mode:
298+
program.global_block().append_op(
299+
type="fetch_barrier",
300+
inputs={},
301+
outputs={},
302+
attrs={
303+
"endpoints": pserver_endpoints,
304+
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
305+
})
304306

305307
for varname, splited_var in self.param_var_mapping.iteritems():
306308
if len(splited_var) <= 1:

0 commit comments

Comments
 (0)