@@ -359,5 +359,79 @@ def transpiler_test_impl(self):
             ["sum", "scale", "scale", "elementwise_add", "momentum"])


+class TestDistLookupTableBase(TranspilerTest):
+    def network_with_table(self, is_sparse, is_distributed):
+        def emb_pool(ids):
+            table_size = 1000
+            emb_size = 64
+            emb = fluid.layers.embedding(
+                input=ids,
+                size=[table_size, emb_size],
+                dtype='float32',
+                param_attr='shared_w',  # share parameter
+                is_sparse=is_sparse,
+                is_distributed=is_distributed)
+            pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
+            return pool
+
+        title_ids = fluid.layers.data(
+            name='title_ids', shape=[1], dtype='int64', lod_level=1)
+        brand_ids = fluid.layers.data(
+            name='brand_ids', shape=[1], dtype='int64', lod_level=1)
+        title_emb = emb_pool(title_ids)
+        brand_emb = emb_pool(brand_ids)
+        fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1)
+        predict = fluid.layers.fc(input=fc0,
+                                  size=2,
+                                  act=None,
+                                  param_attr=fluid.ParamAttr(name='fc_w'),
+                                  bias_attr=fluid.ParamAttr(name='fc_b'))
+
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.Adam(learning_rate=0.003)
+        optimizer.minimize(avg_cost)
+
+
+class TestDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv',
+            'fetch_barrier'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()