@@ -51,25 +51,26 @@ def get_main_program(self):
51
51
self .origin_prog = main .clone ()
52
52
return main
53
53
54
def get_trainer(self, config=None, sync_mode=True):
    """Return the trainer-side program produced by the transpiler.

    Args:
        config: optional DistributeTranspilerConfig forwarded to the
            transpiler; None uses the transpiler's defaults.
        sync_mode: whether to transpile for synchronous training
            (forwarded to ``_transpiler_instance``).
    """
    transpiler = self._transpiler_instance(config, sync_mode)
    return transpiler.get_trainer_program()
57
58
def get_pserver(self, ep, config=None, sync_mode=True):
    """Return ``(pserver_program, startup_program)`` for endpoint *ep*.

    Args:
        ep: the parameter-server endpoint string to build programs for.
        config: optional DistributeTranspilerConfig forwarded to the
            transpiler; None uses the transpiler's defaults.
        sync_mode: whether to transpile for synchronous training.
    """
    transpiler = self._transpiler_instance(config, sync_mode)
    pserver_prog = transpiler.get_pserver_program(ep)
    startup_prog = transpiler.get_startup_program(ep, pserver_prog)
    return pserver_prog, startup_prog
64
def _transpiler_instance(self, config=None, sync_mode=True):
    """Lazily build and memoize the DistributeTranspiler for this test.

    The transpiler is created on first use and cached on ``self``;
    later calls return the cached instance, so ``config`` and
    ``sync_mode`` only take effect on the first call.
    """
    if not self.transpiler:
        main_prog = self.get_main_program()
        self.transpiler = fluid.DistributeTranspiler(config=config)
        self.transpiler.transpile(
            self.trainer_id,
            program=main_prog,
            pservers=self.pserver_eps,
            trainers=self.trainers,
            sync_mode=sync_mode)
    return self.transpiler
75
76
@@ -464,5 +465,76 @@ def transpiler_test_impl(self):
464
465
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
465
466
466
467
468
class TestAsyncLocalLookupTable(TestDistLookupTableBase):
    """Transpile a sparse (non-distributed) lookup-table net in async mode."""

    def net_conf(self):
        # Sparse embedding lookup, but the table itself is not distributed.
        self.network_with_table(is_sparse=True, is_distributed=False)

    def transpiler_test_impl(self):
        config = fluid.DistributeTranspilerConfig()
        # sync_mode=False exercises the asynchronous transpile path.
        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)

        # block 0: listen_and_serv; blocks 1-2: per-param optimize blocks.
        self.assertEqual(len(pserver1.blocks), 3)
        # block 1: adam optimize for fc_w / fc_b.
        fc_opt_ops = [op.type for op in pserver1.blocks[1].ops]
        self.assertEqual(fc_opt_ops, ["adam", "scale", "scale"])
        # block 2: adam optimize for the embedding table.
        # NOTE: if param is not selected rows, the grad will scaled to
        # grad / trainer_num
        table_opt_ops = [op.type for op in pserver1.blocks[2].ops]
        self.assertEqual(table_opt_ops, ["adam", "scale", "scale"])

        # NOTE(review): the trainer program is built with the default
        # sync_mode=True while the pserver above uses async mode — confirm
        # this asymmetry is intentional.
        trainer = self.get_trainer(config)
        self.assertEqual(len(trainer.blocks), 1)
        expected_trainer_ops = [
            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
            'fill_constant', 'mean_grad', 'cross_entropy_grad',
            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
            'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
            'recv', 'recv', 'concat'
        ]
        self.assertEqual([op.type for op in trainer.blocks[0].ops],
                         expected_trainer_ops)
498
+
499
+
500
class TestAsyncDistLookupTable(TestDistLookupTableBase):
    """Transpile a sparse AND distributed lookup-table net in async mode."""

    def net_conf(self):
        # Sparse embedding lookup with the table sharded across pservers.
        self.network_with_table(is_sparse=True, is_distributed=True)

    def transpiler_test_impl(self):
        config = fluid.DistributeTranspilerConfig()

        # sync_mode=False exercises the asynchronous transpile path.
        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)

        self.assertEqual(len(pserver1.blocks), 6)
        # block 0: listen_and_serv.
        # block 1: adam optimize for fc_w / fc_b.
        fc_opt_ops = [op.type for op in pserver1.blocks[1].ops]
        self.assertEqual(fc_opt_ops, ["adam", "scale", "scale"])
        # block 2: sgd optimize for the distributed table.
        self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
        # blocks 3-4: prefetch -> lookup_sparse_table for data0 / data1.
        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
                         ["lookup_sparse_table"])
        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
                         ["lookup_sparse_table"])
        # block 5: checkpoint/save for the table.
        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])

        # NOTE(review): the trainer program is built with the default
        # sync_mode=True while the pserver above uses async mode — confirm
        # this asymmetry is intentional.
        trainer = self.get_trainer(config)
        self.assertEqual(len(trainer.blocks), 1)
        expected_trainer_ops = [
            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
            'sum', 'split_ids', 'send', 'recv', 'recv'
        ]
        self.assertEqual([op.type for op in trainer.blocks[0].ops],
                         expected_trainer_ops)
537
+
538
+
467
539
# Script entry point: run all test cases in this module.
if __name__ == "__main__":
    unittest.main()
0 commit comments