@@ -411,12 +411,12 @@ def network_with_table(self, is_sparse, is_distributed):
411
411
self .emb_size = 64
412
412
self .lookup_table_name = 'shared_w'
413
413
414
- def emb_pool (ids ):
414
+ def emb_pool (ids , table_name , is_distributed ):
415
415
emb = fluid .layers .embedding (
416
416
input = ids ,
417
417
size = [self .table_size , self .emb_size ],
418
418
dtype = 'float32' ,
419
- param_attr = self . lookup_table_name , # share parameter
419
+ param_attr = table_name ,
420
420
is_sparse = is_sparse ,
421
421
is_distributed = is_distributed )
422
422
pool = fluid .layers .sequence_pool (input = emb , pool_type = 'average' )
@@ -426,9 +426,13 @@ def emb_pool(ids):
426
426
name = 'title_ids' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
427
427
brand_ids = fluid .layers .data (
428
428
name = 'brand_ids' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
429
- title_emb = emb_pool (title_ids )
430
- brand_emb = emb_pool (brand_ids )
431
- fc0 = fluid .layers .concat (input = [title_emb , brand_emb ], axis = 1 )
429
+ profile_ids = fluid .layers .data (
430
+ name = 'brand_ids' , shape = [1 ], dtype = 'int64' , lod_level = 1 )
431
+ title_emb = emb_pool (title_ids , self .lookup_table_name , is_distributed )
432
+ brand_emb = emb_pool (brand_ids , self .lookup_table_name , is_distributed )
433
+ profile_emb = emb_pool (profile_ids , "profile_emb" , False )
434
+ fc0 = fluid .layers .concat (
435
+ input = [title_emb , brand_emb , profile_emb ], axis = 1 )
432
436
predict = fluid .layers .fc (input = fc0 ,
433
437
size = 2 ,
434
438
act = None ,
@@ -449,7 +453,7 @@ def net_conf(self):
449
453
def transpiler_test_impl (self ):
450
454
pserver1 , startup1 = self .get_pserver (self .pserver1_ep )
451
455
452
- self .assertEqual (len (pserver1 .blocks ), 3 )
456
+ self .assertEqual (len (pserver1 .blocks ), 4 )
453
457
# 0 listen_and_serv
454
458
# 1 optimize for fc_w or fc_b adam
455
459
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
@@ -459,16 +463,23 @@ def transpiler_test_impl(self):
459
463
self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
460
464
["sum" , "scale" , "adam" , "scale" , "scale" ])
461
465
466
+ # 3 optimize for table 2 adam
467
+ # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
468
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
469
+ ["sum" , "scale" , "adam" , "scale" , "scale" ])
470
+
462
471
trainer , _ = self .get_trainer ()
463
472
self .assertEqual (len (trainer .blocks ), 1 )
464
473
ops = [
465
474
'lookup_table' , 'sequence_pool' , 'lookup_table' , 'sequence_pool' ,
466
- 'concat' , 'mul' , 'elementwise_add' , 'cross_entropy' , 'mean' ,
467
- 'fill_constant' , 'mean_grad' , 'cross_entropy_grad' ,
468
- 'elementwise_add_grad' , 'send' , 'mul_grad' , 'send' , 'concat_grad' ,
469
- 'sequence_pool_grad' , 'lookup_table_grad' , 'sequence_pool_grad' ,
470
- 'lookup_table_grad' , 'sum' , 'split_selected_rows' , 'send' ,
471
- 'send_barrier' , 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat'
475
+ 'lookup_table' , 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
476
+ 'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
477
+ 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' , 'mul_grad' ,
478
+ 'send' , 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
479
+ 'split_selected_rows' , 'send' , 'sequence_pool_grad' ,
480
+ 'lookup_table_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
481
+ 'sum' , 'split_selected_rows' , 'send' , 'send_barrier' , 'recv' ,
482
+ 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat' , 'concat'
472
483
]
473
484
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
474
485
@@ -480,39 +491,45 @@ def net_conf(self):
480
491
def transpiler_test_impl (self ):
481
492
pserver1 , startup1 = self .get_pserver (self .pserver1_ep )
482
493
483
- self .assertEqual (len (pserver1 .blocks ), 5 )
494
+ self .assertEqual (len (pserver1 .blocks ), 6 )
484
495
# 0 listen_and_serv
485
496
# 1 optimize for fc_w or fc_b adam
486
497
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
487
498
["sum" , "scale" , "adam" , "scale" , "scale" ])
488
- # 2 optimize for table sgd
499
+ # 4 prefetch -> lookup_sparse_table for data0
489
500
self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
501
+ ["sum" , "scale" , "adam" , "scale" , "scale" ])
502
+ # 2 optimize for table sgd
503
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
490
504
["sum" , "sgd" ])
491
505
# 3 prefetch -> lookup_sparse_table for data0
492
- self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
506
+ self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ],
493
507
["lookup_sparse_table" ])
494
- # 4 save table
495
- self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ], ["save" ])
508
+ # 5 save table
509
+ self .assertEqual ([op .type for op in pserver1 .blocks [5 ].ops ], ["save" ])
496
510
497
511
trainer , trainer_startup = self .get_trainer ()
498
512
self .assertEqual (len (trainer .blocks ), 1 )
499
513
ops = [
500
514
'split_ids' , 'prefetch' , 'merge_ids' , 'sequence_pool' ,
501
- 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
502
- 'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
503
- 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' , 'mul_grad' ,
504
- 'send' , 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
505
- 'sequence_pool_grad' , 'lookup_table_grad' , 'sum' , 'split_ids' ,
506
- 'send' , 'send_barrier' , 'recv' , 'recv' , 'fetch_barrier'
515
+ 'sequence_pool' , 'lookup_table' , 'sequence_pool' , 'concat' , 'mul' ,
516
+ 'elementwise_add' , 'cross_entropy' , 'mean' , 'fill_constant' ,
517
+ 'mean_grad' , 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' ,
518
+ 'mul_grad' , 'send' , 'concat_grad' , 'sequence_pool_grad' ,
519
+ 'lookup_table_grad' , 'split_selected_rows' , 'send' ,
520
+ 'sequence_pool_grad' , 'lookup_table_grad' , 'sequence_pool_grad' ,
521
+ 'lookup_table_grad' , 'sum' , 'split_ids' , 'send' , 'send_barrier' ,
522
+ 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat'
507
523
]
508
524
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
509
-
510
525
startup_ops = [
511
526
'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
512
527
'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
513
528
'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
514
- 'fill_constant' , 'fill_constant' , 'uniform_random' , 'recv' , 'recv' ,
515
- 'fetch_barrier' , 'fake_init'
529
+ 'fill_constant' , 'fill_constant' , 'fill_constant' , 'fill_constant' ,
530
+ 'fill_constant' , 'fill_constant' , 'uniform_random' ,
531
+ 'uniform_random' , 'recv' , 'recv' , 'recv' , 'fetch_barrier' , 'concat' ,
532
+ 'fake_init'
516
533
]
517
534
self .assertEqual ([op .type for op in trainer_startup .blocks [0 ].ops ],
518
535
startup_ops )
@@ -526,7 +543,7 @@ def transpiler_test_impl(self):
526
543
config = fluid .DistributeTranspilerConfig ()
527
544
pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config , False )
528
545
529
- self .assertEqual (len (pserver1 .blocks ), 3 )
546
+ self .assertEqual (len (pserver1 .blocks ), 4 )
530
547
# 0 listen_and_serv
531
548
# 1 optimize for fc_w or fc_b adam
532
549
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
@@ -535,17 +552,23 @@ def transpiler_test_impl(self):
535
552
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
536
553
self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
537
554
["adam" , "scale" , "scale" ])
555
+ # 3 optimize for table adam
556
+ # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
557
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
558
+ ["adam" , "scale" , "scale" ])
538
559
539
560
trainer , _ = self .get_trainer (config )
540
561
self .assertEqual (len (trainer .blocks ), 1 )
541
562
ops = [
542
563
'lookup_table' , 'sequence_pool' , 'lookup_table' , 'sequence_pool' ,
543
- 'concat' , 'mul' , 'elementwise_add' , 'cross_entropy' , 'mean' ,
544
- 'fill_constant' , 'mean_grad' , 'cross_entropy_grad' ,
545
- 'elementwise_add_grad' , 'send' , 'mul_grad' , 'send' , 'concat_grad' ,
546
- 'sequence_pool_grad' , 'lookup_table_grad' , 'sequence_pool_grad' ,
547
- 'lookup_table_grad' , 'sum' , 'split_selected_rows' , 'send' , 'recv' ,
548
- 'recv' , 'recv' , 'concat'
564
+ 'lookup_table' , 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
565
+ 'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
566
+ 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' , 'mul_grad' ,
567
+ 'send' , 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
568
+ 'split_selected_rows' , 'send' , 'sequence_pool_grad' ,
569
+ 'lookup_table_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
570
+ 'sum' , 'split_selected_rows' , 'send' , 'recv' , 'recv' , 'recv' ,
571
+ 'recv' , 'concat' , 'concat'
549
572
]
550
573
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
551
574
@@ -559,29 +582,34 @@ def transpiler_test_impl(self):
559
582
560
583
pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config , False )
561
584
562
- self .assertEqual (len (pserver1 .blocks ), 5 )
585
+ self .assertEqual (len (pserver1 .blocks ), 6 )
563
586
# 0 listen_and_serv
564
587
# 1 optimize for fc_w or fc_b adam
565
588
self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
566
589
["adam" , "scale" , "scale" ])
567
- # 2 optimize for table sgd
568
- self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ], ["sgd" ])
569
- # 3 prefetch -> lookup_sparse_table for data0
570
- self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
590
+ # 2 optimize for table adam
591
+ self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ],
592
+ ["adam" , "scale" , "scale" ])
593
+ # 3 optimize for table sgd
594
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ], ["sgd" ])
595
+ # 4 prefetch -> lookup_sparse_table for data0
596
+ self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ],
571
597
["lookup_sparse_table" ])
572
- # 4 save table
573
- self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ], ["save" ])
598
+ # 5 save table
599
+ self .assertEqual ([op .type for op in pserver1 .blocks [5 ].ops ], ["save" ])
574
600
575
601
trainer , _ = self .get_trainer (config )
576
602
self .assertEqual (len (trainer .blocks ), 1 )
577
603
ops = [
578
604
'split_ids' , 'prefetch' , 'merge_ids' , 'sequence_pool' ,
579
- 'sequence_pool' , 'concat' , 'mul' , 'elementwise_add' ,
580
- 'cross_entropy' , 'mean' , 'fill_constant' , 'mean_grad' ,
581
- 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' , 'mul_grad' ,
582
- 'send' , 'concat_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
583
- 'sequence_pool_grad' , 'lookup_table_grad' , 'sum' , 'split_ids' ,
584
- 'send' , 'recv' , 'recv'
605
+ 'sequence_pool' , 'lookup_table' , 'sequence_pool' , 'concat' , 'mul' ,
606
+ 'elementwise_add' , 'cross_entropy' , 'mean' , 'fill_constant' ,
607
+ 'mean_grad' , 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' ,
608
+ 'mul_grad' , 'send' , 'concat_grad' , 'sequence_pool_grad' ,
609
+ 'lookup_table_grad' , 'split_selected_rows' , 'send' ,
610
+ 'sequence_pool_grad' , 'lookup_table_grad' , 'sequence_pool_grad' ,
611
+ 'lookup_table_grad' , 'sum' , 'split_ids' , 'send' , 'recv' , 'recv' ,
612
+ 'recv' , 'concat'
585
613
]
586
614
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
587
615
0 commit comments