@@ -394,6 +394,38 @@ def emb_pool(ids):
         optimizer.minimize(avg_cost)
 
 
+class TestLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # print(str(pserver1))
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "adam", "scale", "scale"])
+
+        trainer = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send',
+            'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 class TestDistLookupTable(TestDistLookupTableBase):
     def net_conf(self):
         self.network_with_table(is_sparse=True, is_distributed=True)