@@ -464,5 +464,46 @@ def transpiler_test_impl(self):
464
464
self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
465
465
466
466
467
+ class TestAsyncDistLookupTable (TestDistLookupTableBase ):
468
+ def net_conf (self ):
469
+ self .network_with_table (is_sparse = True , is_distributed = True )
470
+
471
+ def transpiler_test_impl (self ):
472
+ config = fluid .DistributeTranspilerConfig ()
473
+ config .sync_mode = False
474
+
475
+ pserver1 , startup1 = self .get_pserver (self .pserver1_ep , config )
476
+
477
+ self .assertEqual (len (pserver1 .blocks ), 6 )
478
+ # 0 listen_and_serv
479
+ # 1 optimize for fc_w or fc_b adam
480
+ self .assertEqual ([op .type for op in pserver1 .blocks [1 ].ops ],
481
+ ["adam" , "scale" , "scale" ])
482
+ # 2 optimize for table sgd
483
+ self .assertEqual ([op .type for op in pserver1 .blocks [2 ].ops ], ["sgd" ])
484
+ # 3 prefetch -> lookup_sparse_table for data0
485
+ self .assertEqual ([op .type for op in pserver1 .blocks [3 ].ops ],
486
+ ["lookup_sparse_table" ])
487
+ # 4 prefetch -> lookup_sparse_table for data1
488
+ self .assertEqual ([op .type for op in pserver1 .blocks [4 ].ops ],
489
+ ["lookup_sparse_table" ])
490
+ # 5 save table
491
+ self .assertEqual ([op .type for op in pserver1 .blocks [5 ].ops ], ["save" ])
492
+
493
+ trainer = self .get_trainer (config )
494
+ self .assertEqual (len (trainer .blocks ), 1 )
495
+ print ([op .type for op in trainer .blocks [0 ].ops ])
496
+ ops = [
497
+ 'split_ids' , 'prefetch' , 'merge_ids' , 'sequence_pool' , 'split_ids' ,
498
+ 'prefetch' , 'merge_ids' , 'sequence_pool' , 'concat' , 'mul' ,
499
+ 'elementwise_add' , 'cross_entropy' , 'mean' , 'fill_constant' ,
500
+ 'mean_grad' , 'cross_entropy_grad' , 'elementwise_add_grad' , 'send' ,
501
+ 'mul_grad' , 'send' , 'concat_grad' , 'sequence_pool_grad' ,
502
+ 'lookup_table_grad' , 'sequence_pool_grad' , 'lookup_table_grad' ,
503
+ 'sum' , 'split_ids' , 'send' , 'recv' , 'recv'
504
+ ]
505
+ self .assertEqual ([op .type for op in trainer .blocks [0 ].ops ], ops )
506
+
507
+
467
508
if __name__ == "__main__" :
468
509
unittest .main ()
0 commit comments