Skip to content

Commit 7825ae9

Browse files
authored
Merge pull request #14190 from jacquesqiao/dist-table-support-multi-table
Dist table support multi table
2 parents 2ccf77d + f3bbd3b commit 7825ae9

File tree

2 files changed

+80
-47
lines changed

2 files changed

+80
-47
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,12 @@ def network_with_table(self, is_sparse, is_distributed):
411411
self.emb_size = 64
412412
self.lookup_table_name = 'shared_w'
413413

414-
def emb_pool(ids):
414+
def emb_pool(ids, table_name, is_distributed):
415415
emb = fluid.layers.embedding(
416416
input=ids,
417417
size=[self.table_size, self.emb_size],
418418
dtype='float32',
419-
param_attr=self.lookup_table_name, # share parameter
419+
param_attr=table_name,
420420
is_sparse=is_sparse,
421421
is_distributed=is_distributed)
422422
pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
@@ -426,9 +426,13 @@ def emb_pool(ids):
426426
name='title_ids', shape=[1], dtype='int64', lod_level=1)
427427
brand_ids = fluid.layers.data(
428428
name='brand_ids', shape=[1], dtype='int64', lod_level=1)
429-
title_emb = emb_pool(title_ids)
430-
brand_emb = emb_pool(brand_ids)
431-
fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1)
429+
profile_ids = fluid.layers.data(
430+
name='profile_ids', shape=[1], dtype='int64', lod_level=1)
431+
title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed)
432+
brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed)
433+
profile_emb = emb_pool(profile_ids, "profile_emb", False)
434+
fc0 = fluid.layers.concat(
435+
input=[title_emb, brand_emb, profile_emb], axis=1)
432436
predict = fluid.layers.fc(input=fc0,
433437
size=2,
434438
act=None,
@@ -449,7 +453,7 @@ def net_conf(self):
449453
def transpiler_test_impl(self):
450454
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
451455

452-
self.assertEqual(len(pserver1.blocks), 3)
456+
self.assertEqual(len(pserver1.blocks), 4)
453457
# 0 listen_and_serv
454458
# 1 optimize for fc_w or fc_b adam
455459
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
@@ -459,16 +463,23 @@ def transpiler_test_impl(self):
459463
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
460464
["sum", "scale", "adam", "scale", "scale"])
461465

466+
# 3 optimize for table 2 adam
467+
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
468+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
469+
["sum", "scale", "adam", "scale", "scale"])
470+
462471
trainer, _ = self.get_trainer()
463472
self.assertEqual(len(trainer.blocks), 1)
464473
ops = [
465474
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
466-
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
467-
'fill_constant', 'mean_grad', 'cross_entropy_grad',
468-
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
469-
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
470-
'lookup_table_grad', 'sum', 'split_selected_rows', 'send',
471-
'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
475+
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
476+
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
477+
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
478+
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
479+
'split_selected_rows', 'send', 'sequence_pool_grad',
480+
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
481+
'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
482+
'recv', 'recv', 'recv', 'fetch_barrier', 'concat', 'concat'
472483
]
473484
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
474485

@@ -480,39 +491,45 @@ def net_conf(self):
480491
def transpiler_test_impl(self):
481492
pserver1, startup1 = self.get_pserver(self.pserver1_ep)
482493

483-
self.assertEqual(len(pserver1.blocks), 5)
494+
self.assertEqual(len(pserver1.blocks), 6)
484495
# 0 listen_and_serv
485496
# 1 optimize for fc_w or fc_b adam
486497
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
487498
["sum", "scale", "adam", "scale", "scale"])
488-
# 2 optimize for table sgd
499+
# 2 optimize for table adam
489500
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
501+
["sum", "scale", "adam", "scale", "scale"])
502+
# 2 optimize for table sgd
503+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
490504
["sum", "sgd"])
491505
# 3 prefetch -> lookup_sparse_table for data0
492-
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
506+
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
493507
["lookup_sparse_table"])
494-
# 4 save table
495-
self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
508+
# 5 save table
509+
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
496510

497511
trainer, trainer_startup = self.get_trainer()
498512
self.assertEqual(len(trainer.blocks), 1)
499513
ops = [
500514
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
501-
'sequence_pool', 'concat', 'mul', 'elementwise_add',
502-
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
503-
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
504-
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
505-
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
506-
'send', 'send_barrier', 'recv', 'recv', 'fetch_barrier'
515+
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
516+
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
517+
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
518+
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
519+
'lookup_table_grad', 'split_selected_rows', 'send',
520+
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
521+
'lookup_table_grad', 'sum', 'split_ids', 'send', 'send_barrier',
522+
'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
507523
]
508524
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
509-
510525
startup_ops = [
511526
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
512527
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
513528
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
514-
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv',
515-
'fetch_barrier', 'fake_init'
529+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
530+
'fill_constant', 'fill_constant', 'uniform_random',
531+
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
532+
'fake_init'
516533
]
517534
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
518535
startup_ops)
@@ -526,7 +543,7 @@ def transpiler_test_impl(self):
526543
config = fluid.DistributeTranspilerConfig()
527544
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
528545

529-
self.assertEqual(len(pserver1.blocks), 3)
546+
self.assertEqual(len(pserver1.blocks), 4)
530547
# 0 listen_and_serv
531548
# 1 optimize for fc_w or fc_b adam
532549
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
@@ -535,17 +552,23 @@ def transpiler_test_impl(self):
535552
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
536553
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
537554
["adam", "scale", "scale"])
555+
# 3 optimize for table adam
556+
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
557+
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
558+
["adam", "scale", "scale"])
538559

539560
trainer, _ = self.get_trainer(config)
540561
self.assertEqual(len(trainer.blocks), 1)
541562
ops = [
542563
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
543-
'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
544-
'fill_constant', 'mean_grad', 'cross_entropy_grad',
545-
'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
546-
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
547-
'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
548-
'recv', 'recv', 'concat'
564+
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
565+
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
566+
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
567+
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
568+
'split_selected_rows', 'send', 'sequence_pool_grad',
569+
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
570+
'sum', 'split_selected_rows', 'send', 'recv', 'recv', 'recv',
571+
'recv', 'concat', 'concat'
549572
]
550573
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
551574

@@ -559,29 +582,34 @@ def transpiler_test_impl(self):
559582

560583
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
561584

562-
self.assertEqual(len(pserver1.blocks), 5)
585+
self.assertEqual(len(pserver1.blocks), 6)
563586
# 0 listen_and_serv
564587
# 1 optimize for fc_w or fc_b adam
565588
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
566589
["adam", "scale", "scale"])
567-
# 2 optimize for table sgd
568-
self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
569-
# 3 prefetch -> lookup_sparse_table for data0
570-
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
590+
# 2 optimize for table adam
591+
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
592+
["adam", "scale", "scale"])
593+
# 3 optimize for table sgd
594+
self.assertEqual([op.type for op in pserver1.blocks[3].ops], ["sgd"])
595+
# 4 prefetch -> lookup_sparse_table for data0
596+
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
571597
["lookup_sparse_table"])
572-
# 4 save table
573-
self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
598+
# 5 save table
599+
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
574600

575601
trainer, _ = self.get_trainer(config)
576602
self.assertEqual(len(trainer.blocks), 1)
577603
ops = [
578604
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
579-
'sequence_pool', 'concat', 'mul', 'elementwise_add',
580-
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
581-
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
582-
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
583-
'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
584-
'send', 'recv', 'recv'
605+
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
606+
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
607+
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
608+
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
609+
'lookup_table_grad', 'split_selected_rows', 'send',
610+
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
611+
'lookup_table_grad', 'sum', 'split_ids', 'send', 'recv', 'recv',
612+
'recv', 'concat'
585613
]
586614
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
587615

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,12 @@ def _replace_lookup_table_op_with_prefetch(self, program,
10651065
continue_search_lookup_table_op = False
10661066
all_ops = program.global_block().ops
10671067
for op in all_ops:
1068-
if op.type == LOOKUP_TABLE_TYPE:
1068+
if op.type == LOOKUP_TABLE_TYPE and self.table_name == op.input(
1069+
"W")[0]:
1070+
if not op.attr('is_distributed'):
1071+
raise RuntimeError(
1072+
"lookup_table_op that lookup an distributed embedding table"
1073+
"should set is_distributed to true")
10691074
continue_search_lookup_table_op = True
10701075

10711076
lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(

0 commit comments

Comments
 (0)