@@ -411,12 +411,12 @@ def network_with_table(self, is_sparse, is_distributed):
411
411
self .emb_size = 64
412
412
self .lookup_table_name = 'shared_w'
413
413
414
- def emb_pool (ids ):
414
+ def emb_pool (ids , table_name , is_distributed ):
415
415
emb = fluid .layers .embedding (
416
416
input = ids ,
417
417
size = [self .table_size , self .emb_size ],
418
418
dtype = 'float32' ,
419
- param_attr = self . lookup_table_name , # share parameter
419
+ param_attr = table_name ,
420
420
is_sparse = is_sparse ,
421
421
is_distributed = is_distributed )
422
422
pool = fluid .layers .sequence_pool (input = emb , pool_type = 'average' )
@@ -426,9 +426,12 @@ def emb_pool(ids):
426
426
name = 'title_ids' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
427
427
brand_ids = fluid .layers .data (
428
428
name = 'brand_ids' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
429
- title_emb = emb_pool (title_ids )
430
- brand_emb = emb_pool (brand_ids )
431
- fc0 = fluid .layers .concat (input = [title_emb , brand_emb ], axis = 1 )
429
+ profile_ids = fluid .layers .data (
430
+ name = 'brand_ids' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
431
+ title_emb = emb_pool (title_ids , self .lookup_table_name , is_distributed )
432
+ brand_emb = emb_pool (brand_ids , self .lookup_table_name , is_distributed )
433
+ profile_emb = emb_pool (profile_ids , "profile_emb" , False )
434
+ fc0 = fluid .layers .concat (input = [title_emb , brand_emb , profile_emb ], axis = 1 )
432
435
predict = fluid .layers .fc (input = fc0 ,
433
436
size = 2 ,
434
437
act = None ,
@@ -449,7 +452,7 @@ def net_conf(self):
449
452
def transpiler_test_impl (self ):
450
453
pserver1 , startup1 = self .get_pserver (self .pserver1_ep )
451
454
452
- self .assertEqual (len (pserver1 .blocks ), 3 )
455
+ self .assertEqual (len (pserver1 .blocks ), 4 )
453
456
# 0 listen_and_serv
454
457
# 1 optimize for fc_w or fc_b adam
455
458
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
@@ -459,16 +462,22 @@ def transpiler_test_impl(self):
459
462
self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
460
463
["sum" , "scale" , "adam" , "scale" , "scale" ])
461
464
465
+ # 3 optimize for table 2 adam
466
+ # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
467
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
468
+ ["sum" , "scale" , "adam" , "scale" , "scale" ])
469
+
462
470
trainer , _ = self .get_trainer ()
463
471
self .assertEqual (len (trainer .blocks ), 1 )
464
472
ops = [
465
473
'lookup_table' , 'sequence_pool' , 'lookup_table' , 'sequence_pool' ,
466
- 'concat' , 'mul' , 'elementwise_add' , 'cross_entropy' , 'mean' ,
474
+ 'lookup_table' , 'sequence_pool' , ' concat' , 'mul' , 'elementwise_add' , 'cross_entropy' , 'mean' ,
467
475
'fill_constant' , 'mean_grad' , 'cross_entropy_grad' ,
468
476
'elementwise_add_grad' , 'send' , 'mul_grad' , 'send' , 'concat_grad' ,
477
+ 'sequence_pool_grad' , 'lookup_table_grad' , 'split_selected_rows' , 'send' ,
469
478
'sequence_pool_grad' , 'lookup_table_grad' , 'sequence_pool_grad' ,
470
479
'lookup_table_grad' , 'sum' , 'split_selected_rows' , 'send' ,
471
- 'send_barrier' , 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat'
480
+ 'send_barrier' , 'recv' , 'recv' , 'recv' , 'recv' , ' fetch_barrier' , 'concat ' , 'concat'
472
481
]
473
482
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
474
483
@@ -480,40 +489,42 @@ def net_conf(self):
480
489
def transpiler_test_impl (self ):
481
490
pserver1 , startup1 = self .get_pserver (self .pserver1_ep )
482
491
483
- self .assertEqual (len (pserver1 .blocks ), 5 )
492
+ self .assertEqual (len (pserver1 .blocks ), 6 )
484
493
# 0 listen_and_serv
485
494
# 1 optimize for fc_w or fc_b adam
486
495
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
487
496
["sum" , "scale" , "adam" , "scale" , "scale" ])
488
- # 2 optimize for table sgd
497
+ # 4 prefetch -> lookup_sparse_table for data0
489
498
self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
499
+ ["sum" , "scale" , "adam" , "scale" , "scale" ])
500
+ # 2 optimize for table sgd
501
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
490
502
["sum" , "sgd" ])
491
503
# 3 prefetch -> lookup_sparse_table for data0
492
- self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
504
+ self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ],
493
505
["lookup_sparse_table" ])
494
- # 4 save table
495
- self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ], ["save" ])
506
+ # 5 save table
507
+ self .assertEqual ([op .type for op in pserver1 .blocks [5 ].ops ], ["save" ])
496
508
497
509
trainer , trainer_startup = self .get_trainer ()
498
510
self .assertEqual (len (trainer .blocks ), 1 )
499
511
ops = [
500
512
'split_ids' , 'prefetch' , 'merge_ids' , 'sequence_pool' ,
501
- 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
513
+ 'sequence_pool' , 'lookup_table' , 'sequence_pool' , ' concat' , 'mul' , 'elementwise_add' ,
502
514
'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
503
515
'cross_entropy_grad' , 'elementwise_add_grad' , 'send' , 'mul_grad' ,
504
516
'send' , 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
505
- 'sequence_pool_grad ' , 'lookup_table_grad ' , 'sum ' , 'split_ids ' ,
506
- 'send ' , 'send_barrier ' , 'recv ' , 'recv ' , 'fetch_barrier'
507
- ]
517
+ 'split_selected_rows ' , 'send ' , 'sequence_pool_grad ' , 'lookup_table_grad ' ,
518
+ 'sequence_pool_grad ' , 'lookup_table_grad ' , 'sum ' , 'split_ids ' , 'send' , 'send_barrier' ,
519
+ 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat' ]
508
520
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
509
-
510
521
startup_ops = [
511
522
'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
512
523
'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
513
524
'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
514
- 'fill_constant' , 'fill_constant' , 'uniform_random ' , 'recv' , 'recv ' ,
515
- 'fetch_barrier ' , 'fake_init'
516
- ]
525
+ 'fill_constant' , 'fill_constant' , 'fill_constant ' , 'fill_constant ' ,
526
+ 'fill_constant ' , 'fill_constant' , 'uniform_random' , 'uniform_random' ,
527
+ 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat' , 'fake_init' ]
517
528
self .assertEqual ([op .type for op in trainer_startup .blocks [0 ].ops ],
518
529
startup_ops )
519
530
@@ -526,7 +537,7 @@ def transpiler_test_impl(self):
526
537
config = fluid .DistributeTranspilerConfig ()
527
538
pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config , False )
528
539
529
- self .assertEqual (len (pserver1 .blocks ), 3 )
540
+ self .assertEqual (len (pserver1 .blocks ), 4 )
530
541
# 0 listen_and_serv
531
542
# 1 optimize for fc_w or fc_b adam
532
543
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
@@ -535,17 +546,24 @@ def transpiler_test_impl(self):
535
546
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
536
547
self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
537
548
["adam" , "scale" , "scale" ])
549
+ # 3 optimize for table adam
550
+ # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
551
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
552
+ ["adam" , "scale" , "scale" ])
538
553
539
554
trainer , _ = self .get_trainer (config )
540
555
self .assertEqual (len (trainer .blocks ), 1 )
541
556
ops = [
542
557
'lookup_table' , 'sequence_pool' , 'lookup_table' , 'sequence_pool' ,
543
- 'concat' , 'mul' , 'elementwise_add' , 'cross_entropy' , 'mean' ,
544
- 'fill_constant' , 'mean_grad' , 'cross_entropy_grad' ,
545
- 'elementwise_add_grad' , 'send' , 'mul_grad' , 'send' , 'concat_grad' ,
546
- 'sequence_pool_grad' , 'lookup_table_grad' , 'sequence_pool_grad' ,
547
- 'lookup_table_grad' , 'sum' , 'split_selected_rows' , 'send' , 'recv' ,
548
- 'recv' , 'recv' , 'concat'
558
+ 'lookup_table' , 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
559
+ 'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
560
+ 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' ,
561
+ 'mul_grad' , 'send' , 'concat_grad' , 'sequence_pool_grad' ,
562
+ 'lookup_table_grad' , 'split_selected_rows' , 'send' ,
563
+ 'sequence_pool_grad' , 'lookup_table_grad' ,
564
+ 'sequence_pool_grad' , 'lookup_table_grad' ,
565
+ 'sum' , 'split_selected_rows' , 'send' , 'recv' , 'recv' , 'recv' , 'recv' ,
566
+ 'concat' , 'concat'
549
567
]
550
568
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
551
569
@@ -559,30 +577,34 @@ def transpiler_test_impl(self):
559
577
560
578
pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config , False )
561
579
562
- self .assertEqual (len (pserver1 .blocks ), 5 )
580
+ self .assertEqual (len (pserver1 .blocks ), 6 )
563
581
# 0 listen_and_serv
564
582
# 1 optimize for fc_w or fc_b adam
565
583
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
566
584
["adam" , "scale" , "scale" ])
567
- # 2 optimize for table sgd
568
- self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ], ["sgd" ])
569
- # 3 prefetch -> lookup_sparse_table for data0
570
- self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
585
+ # 2 optimize for table adam
586
+ self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
587
+ ["adam" , "scale" , "scale" ])
588
+ # 3 optimize for table sgd
589
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ], ["sgd" ])
590
+ # 4 prefetch -> lookup_sparse_table for data0
591
+ self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ],
571
592
["lookup_sparse_table" ])
572
- # 4 save table
573
- self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ], ["save" ])
593
+ # 5 save table
594
+ self .assertEqual ([op .type for op in pserver1 .blocks [5 ].ops ], ["save" ])
574
595
575
596
trainer , _ = self .get_trainer (config )
576
597
self .assertEqual (len (trainer .blocks ), 1 )
577
598
ops = [
578
599
'split_ids' , 'prefetch' , 'merge_ids' , 'sequence_pool' ,
579
- 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
580
- 'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
581
- 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' , 'mul_grad' ,
582
- 'send' , 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
583
- 'sequence_pool_grad' , 'lookup_table_grad' , 'sum' , 'split_ids' ,
584
- 'send' , 'recv' , 'recv'
585
- ]
600
+ 'sequence_pool' , 'lookup_table' , 'sequence_pool' ,
601
+ 'concat' , 'mul' , 'elementwise_add' , 'cross_entropy' ,
602
+ 'mean' , 'fill_constant' , 'mean_grad' , 'cross_entropy_grad' ,
603
+ 'elementwise_add_grad' , 'send' , 'mul_grad' , 'send' ,
604
+ 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
605
+ 'split_selected_rows' , 'send' , 'sequence_pool_grad' ,
606
+ 'lookup_table_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
607
+ 'sum' , 'split_ids' , 'send' , 'recv' , 'recv' , 'recv' , 'concat' ]
586
608
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
587
609
588
610
0 commit comments