Skip to content

Commit 5a608fc

Browse files
committed
add TestDistLookupTable
1 parent c0e8dd8 commit 5a608fc

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,38 @@ def emb_pool(ids):
394394
optimizer.minimize(avg_cost)
395395

396396

397+
class TestLocalLookupTable(TestDistLookupTableBase):
398+
def net_conf(self):
399+
self.network_with_table(is_sparse=True, is_distributed=False)
400+
401+
def transpiler_test_impl(self):
402+
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
403+
404+
self.assertEqual(len(pserver1.blocks), 3)
405+
# print(str(pserver1))
406+
# 0 listen_and_serv
407+
# 1 optimize for fc_w or fc_b adam
408+
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
409+
["sum", "scale", "adam", "scale", "scale"])
410+
# 2 optimize for table adam
411+
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
412+
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
413+
["sum", "adam", "scale", "scale"])
414+
415+
trainer = self.get_trainer()
416+
self.assertEqual(len(trainer.blocks), 1)
417+
ops = [
418+
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
419+
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
420+
'fill_constant', 'mean_grad', 'cross_entropy_grad',
421+
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
422+
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
423+
'lookup_table_grad', 'sum', 'split_selected_rows', 'send',
424+
'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
425+
]
426+
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
427+
428+
397429
class TestDistLookupTable(TestDistLookupTableBase):
398430
def net_conf(self):
399431
self.network_with_table(is_sparse=True, is_distributed=True)

0 commit comments

Comments
 (0)