Skip to content

Commit 1b69021

Browse files
committed
add TestAsyncLocalLookupTable
1 parent 8da6510 commit 1b69021

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,39 @@ def transpiler_test_impl(self):
464464
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
465465

466466

467+
class TestAsyncLocalLookupTable(TestDistLookupTableBase):
468+
def net_conf(self):
469+
self.network_with_table(is_sparse=True, is_distributed=False)
470+
471+
def transpiler_test_impl(self):
472+
config = fluid.DistributeTranspilerConfig()
473+
config.sync_mode = False
474+
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
475+
476+
self.assertEqual(len(pserver1.blocks), 3)
477+
# 0 listen_and_serv
478+
# 1 optimize for fc_w or fc_b adam
479+
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
480+
["adam", "scale", "scale"])
481+
# 2 optimize for table adam
482+
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
483+
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
484+
["adam", "scale", "scale"])
485+
486+
trainer = self.get_trainer(config)
487+
self.assertEqual(len(trainer.blocks), 1)
488+
ops = [
489+
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
490+
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
491+
'fill_constant', 'mean_grad', 'cross_entropy_grad',
492+
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
493+
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
494+
'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
495+
'recv', 'recv', 'concat'
496+
]
497+
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
498+
499+
467500
class TestAsyncDistLookupTable(TestDistLookupTableBase):
468501
def net_conf(self):
469502
self.network_with_table(is_sparse=True, is_distributed=True)

0 commit comments

Comments
 (0)