Skip to content

Commit add4b46

Browse files
committed
dist table transpiler only handles tables with is_distributed=True
1 parent d186e74 commit add4b46

File tree

2 files changed

+68
-43
lines changed

2 files changed

+68
-43
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 64 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,12 @@ def network_with_table(self, is_sparse, is_distributed):
411411
self.emb_size = 64
412412
self.lookup_table_name = 'shared_w'
413413

414-
def emb_pool(ids):
414+
def emb_pool(ids, table_name, is_distributed):
415415
emb = fluid.layers.embedding(
416416
input=ids,
417417
size=[self.table_size, self.emb_size],
418418
dtype='float32',
419-
param_attr=self.lookup_table_name, # share parameter
419+
param_attr=table_name,
420420
is_sparse=is_sparse,
421421
is_distributed=is_distributed)
422422
pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
@@ -426,9 +426,12 @@ def emb_pool(ids):
426426
name='title_ids', shape=[1], dtype='int64', lod_level=1)
427427
brand_ids = fluid.layers.data(
428428
name='brand_ids', shape=[1], dtype='int64', lod_level=1)
429-
title_emb = emb_pool(title_ids)
430-
brand_emb = emb_pool(brand_ids)
431-
fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1)
429+
profile_ids = fluid.layers.data(
430+
name='profile_ids', shape=[1], dtype='int64', lod_level=1)
431+
title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed)
432+
brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed)
433+
profile_emb = emb_pool(profile_ids, "profile_emb", False)
434+
fc0 = fluid.layers.concat(input=[title_emb, brand_emb, profile_emb], axis=1)
432435
predict = fluid.layers.fc(input=fc0,
433436
size=2,
434437
act=None,
@@ -449,7 +452,7 @@ def net_conf(self):
449452
def transpiler_test_impl(self):
450453
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
451454

452-
self.assertEqual(len(pserver1.blocks), 3)
455+
self.assertEqual(len(pserver1.blocks), 4)
453456
# 0 listen_and_serv
454457
# 1 optimize for fc_w or fc_b adam
455458
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
@@ -459,16 +462,22 @@ def transpiler_test_impl(self):
459462
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
460463
["sum", "scale", "adam", "scale", "scale"])
461464

465+
# 3 optimize for table 2 adam
466+
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
467+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
468+
["sum", "scale", "adam", "scale", "scale"])
469+
462470
trainer, _ = self.get_trainer()
463471
self.assertEqual(len(trainer.blocks), 1)
464472
ops = [
465473
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
466-
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
474+
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
467475
'fill_constant', 'mean_grad', 'cross_entropy_grad',
468476
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
477+
'sequence_pool_grad', 'lookup_table_grad', 'split_selected_rows', 'send',
469478
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
470479
'lookup_table_grad', 'sum', 'split_selected_rows', 'send',
471-
'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
480+
'send_barrier', 'recv', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat', 'concat'
472481
]
473482
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
474483

@@ -480,40 +489,42 @@ def net_conf(self):
480489
def transpiler_test_impl(self):
481490
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
482491

483-
self.assertEqual(len(pserver1.blocks), 5)
492+
self.assertEqual(len(pserver1.blocks), 6)
484493
# 0 listen_and_serv
485494
# 1 optimize for fc_w or fc_b adam
486495
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
487496
["sum", "scale", "adam", "scale", "scale"])
488-
# 2 optimize for table sgd
497+
# 4 prefetch -> lookup_sparse_table for data0
489498
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
499+
["sum", "scale", "adam", "scale", "scale"])
500+
# 2 optimize for table sgd
501+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
490502
["sum", "sgd"])
491503
# 3 prefetch -> lookup_sparse_table for data0
492-
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
504+
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
493505
["lookup_sparse_table"])
494-
# 4 save table
495-
self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
506+
# 5 save table
507+
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
496508

497509
trainer, trainer_startup = self.get_trainer()
498510
self.assertEqual(len(trainer.blocks), 1)
499511
ops = [
500512
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
501-
'sequence_pool', 'concat', 'mul', 'elementwise_add',
513+
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
502514
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
503515
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
504516
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
505-
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
506-
'send', 'send_barrier', 'recv', 'recv', 'fetch_barrier'
507-
]
517+
'split_selected_rows', 'send', 'sequence_pool_grad', 'lookup_table_grad',
518+
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids', 'send', 'send_barrier',
519+
'recv', 'recv', 'recv', 'fetch_barrier', 'concat']
508520
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
509-
510521
startup_ops = [
511522
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
512523
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
513524
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
514-
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv',
515-
'fetch_barrier', 'fake_init'
516-
]
525+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
526+
'fill_constant', 'fill_constant', 'uniform_random', 'uniform_random',
527+
'recv', 'recv', 'recv', 'fetch_barrier', 'concat', 'fake_init']
517528
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
518529
startup_ops)
519530

@@ -526,7 +537,7 @@ def transpiler_test_impl(self):
526537
config = fluid.DistributeTranspilerConfig()
527538
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
528539

529-
self.assertEqual(len(pserver1.blocks), 3)
540+
self.assertEqual(len(pserver1.blocks), 4)
530541
# 0 listen_and_serv
531542
# 1 optimize for fc_w or fc_b adam
532543
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
@@ -535,17 +546,24 @@ def transpiler_test_impl(self):
535546
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
536547
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
537548
["adam", "scale", "scale"])
549+
# 3 optimize for table adam
550+
# NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
551+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
552+
["adam", "scale", "scale"])
538553

539554
trainer, _ = self.get_trainer(config)
540555
self.assertEqual(len(trainer.blocks), 1)
541556
ops = [
542557
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
543-
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
544-
'fill_constant', 'mean_grad', 'cross_entropy_grad',
545-
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
546-
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
547-
'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
548-
'recv', 'recv', 'concat'
558+
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
559+
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
560+
'cross_entropy_grad', 'elementwise_add_grad', 'send',
561+
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
562+
'lookup_table_grad', 'split_selected_rows', 'send',
563+
'sequence_pool_grad', 'lookup_table_grad',
564+
'sequence_pool_grad', 'lookup_table_grad',
565+
'sum', 'split_selected_rows', 'send', 'recv', 'recv', 'recv', 'recv',
566+
'concat', 'concat'
549567
]
550568
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
551569

@@ -559,30 +577,34 @@ def transpiler_test_impl(self):
559577

560578
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
561579

562-
self.assertEqual(len(pserver1.blocks), 5)
580+
self.assertEqual(len(pserver1.blocks), 6)
563581
# 0 listen_and_serv
564582
# 1 optimize for fc_w or fc_b adam
565583
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
566584
["adam", "scale", "scale"])
567-
# 2 optimize for table sgd
568-
self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
569-
# 3 prefetch -> lookup_sparse_table for data0
570-
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
585+
# 2 optimize for table adam
586+
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
587+
["adam", "scale", "scale"])
588+
# 3 optimize for table sgd
589+
self.assertEqual([op.type for op in pserver1.blocks[3].ops], ["sgd"])
590+
# 4 prefetch -> lookup_sparse_table for data0
591+
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
571592
["lookup_sparse_table"])
572-
# 4 save table
573-
self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
593+
# 5 save table
594+
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
574595

575596
trainer, _ = self.get_trainer(config)
576597
self.assertEqual(len(trainer.blocks), 1)
577598
ops = [
578599
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
579-
'sequence_pool', 'concat', 'mul', 'elementwise_add',
580-
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
581-
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
582-
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
583-
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
584-
'send', 'recv', 'recv'
585-
]
600+
'sequence_pool', 'lookup_table', 'sequence_pool',
601+
'concat', 'mul', 'elementwise_add', 'cross_entropy',
602+
'mean', 'fill_constant', 'mean_grad', 'cross_entropy_grad',
603+
'elementwise_add_grad', 'send', 'mul_grad', 'send',
604+
'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
605+
'split_selected_rows', 'send', 'sequence_pool_grad',
606+
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
607+
'sum', 'split_ids', 'send', 'recv', 'recv', 'recv', 'concat']
586608
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
587609

588610

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,10 @@ def _replace_lookup_table_op_with_prefetch(self, program,
10651065
continue_search_lookup_table_op = False
10661066
all_ops = program.global_block().ops
10671067
for op in all_ops:
1068-
if op.type == LOOKUP_TABLE_TYPE:
1068+
if op.type == LOOKUP_TABLE_TYPE and self.table_name == op.input("W")[0]:
1069+
if not op.attr('is_distributed'):
1070+
raise RuntimeError("lookup_table_op that looks up a distributed embedding table "
1071+
"should set is_distributed to true")
10691072
continue_search_lookup_table_op = True
10701073

10711074
lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(

0 commit comments

Comments
 (0)