Skip to content

Commit c0e8dd8

Browse files
committed
add unit test for dist lookup table
1 parent fd53fdf commit c0e8dd8

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,5 +359,79 @@ def transpiler_test_impl(self):
359359
["sum", "scale", "scale", "elementwise_add", "momentum"])
360360

361361

362+
class TestDistLookupTableBase(TranspilerTest):
363+
def network_with_table(self, is_sparse, is_distributed):
364+
def emb_pool(ids):
365+
table_size = 1000
366+
emb_size = 64
367+
emb = fluid.layers.embedding(
368+
input=ids,
369+
size=[table_size, emb_size],
370+
dtype='float32',
371+
param_attr='shared_w', # share parameter
372+
is_sparse=is_sparse,
373+
is_distributed=is_distributed)
374+
pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
375+
return pool
376+
377+
title_ids = fluid.layers.data(
378+
name='title_ids', shape=[1], dtype='int64', lod_level=1)
379+
brand_ids = fluid.layers.data(
380+
name='brand_ids', shape=[1], dtype='int64', lod_level=1)
381+
title_emb = emb_pool(title_ids)
382+
brand_emb = emb_pool(brand_ids)
383+
fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1)
384+
predict = fluid.layers.fc(input=fc0,
385+
size=2,
386+
act=None,
387+
param_attr=fluid.ParamAttr(name='fc_w'),
388+
bias_attr=fluid.ParamAttr(name='fc_b'))
389+
390+
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
391+
cost = fluid.layers.cross_entropy(input=predict, label=label)
392+
avg_cost = fluid.layers.mean(cost)
393+
optimizer = fluid.optimizer.Adam(learning_rate=0.003)
394+
optimizer.minimize(avg_cost)
395+
396+
397+
class TestDistLookupTable(TestDistLookupTableBase):
398+
def net_conf(self):
399+
self.network_with_table(is_sparse=True, is_distributed=True)
400+
401+
def transpiler_test_impl(self):
402+
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
403+
404+
self.assertEqual(len(pserver1.blocks), 6)
405+
# 0 listen_and_serv
406+
# 1 optimize for fc_w or fc_b adam
407+
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
408+
["sum", "scale", "adam", "scale", "scale"])
409+
# 2 optimize for table sgd
410+
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
411+
["sum", "sgd"])
412+
# 3 prefetch -> lookup_sparse_table for data0
413+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
414+
["lookup_sparse_table"])
415+
# 4 prefetch -> lookup_sparse_table for data1
416+
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
417+
["lookup_sparse_table"])
418+
# 5 save table
419+
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
420+
421+
trainer = self.get_trainer()
422+
self.assertEqual(len(trainer.blocks), 1)
423+
ops = [
424+
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
425+
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
426+
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
427+
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
428+
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
429+
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
430+
'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv',
431+
'fetch_barrier'
432+
]
433+
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
434+
435+
362436
if __name__ == "__main__":
363437
unittest.main()

0 commit comments

Comments
 (0)