@@ -464,6 +464,39 @@ def transpiler_test_impl(self):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestAsyncLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+        config.sync_mode = False
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if the param is not selected rows, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["adam", "scale", "scale"])
+
+        trainer = self.get_trainer(config)
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send', 'recv',
+            'recv', 'recv', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 class TestAsyncDistLookupTable(TestDistLookupTableBase):
     def net_conf(self):
         self.network_with_table(is_sparse=True, is_distributed=True)
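
For context on what the new test exercises: `get_pserver` and `get_trainer` are helpers on `TestDistLookupTableBase` that wrap Fluid's `DistributeTranspiler`. The sketch below is a minimal illustration of that flow under async mode, not the test base itself; the endpoints, trainer count, and the `build_async_programs` helper name are assumptions made for illustration.

```python
import paddle.fluid as fluid

def build_async_programs(main_program):
    # Illustrative endpoints and trainer count; the real test base
    # class supplies its own values (e.g. self.pserver1_ep).
    pserver1_ep = "127.0.0.1:6174"
    pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"

    config = fluid.DistributeTranspilerConfig()
    config.sync_mode = False  # async mode, as set in the test above

    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id=0,
        program=main_program,
        pservers=pserver_eps,
        trainers=2)

    # The programs the test inspects: pserver blocks and trainer ops.
    pserver_prog = t.get_pserver_program(pserver1_ep)
    startup_prog = t.get_startup_program(pserver1_ep, pserver_prog)
    trainer_prog = t.get_trainer_program()
    return pserver_prog, startup_prog, trainer_prog
```

With `sync_mode` off there is no synchronization barrier between trainers, and each parameter's adam update is placed in its own pserver block, which matches the layout the test asserts: block 0 for listen_and_serv, block 1 for the fc_w/fc_b adam, and block 2 for the table adam.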