@@ -359,5 +359,110 @@ def transpiler_test_impl(self):
                          ["sum", "scale", "scale", "elementwise_add", "momentum"])
 
 
+class TestDistLookupTableBase(TranspilerTest):
+    def network_with_table(self, is_sparse, is_distributed):
+        def emb_pool(ids):
+            table_size = 1000
+            emb_size = 64
+            emb = fluid.layers.embedding(
+                input=ids,
+                size=[table_size, emb_size],
+                dtype='float32',
+                param_attr='shared_w',  # share parameter
+                is_sparse=is_sparse,
+                is_distributed=is_distributed)
+            pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
+            return pool
+
+        title_ids = fluid.layers.data(
+            name='title_ids', shape=[1], dtype='int64', lod_level=1)
+        brand_ids = fluid.layers.data(
+            name='brand_ids', shape=[1], dtype='int64', lod_level=1)
+        title_emb = emb_pool(title_ids)
+        brand_emb = emb_pool(brand_ids)
+        fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1)
+        predict = fluid.layers.fc(input=fc0,
+                                  size=2,
+                                  act=None,
+                                  param_attr=fluid.ParamAttr(name='fc_w'),
+                                  bias_attr=fluid.ParamAttr(name='fc_b'))
+
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.Adam(learning_rate=0.003)
+        optimizer.minimize(avg_cost)
+
+
+class TestLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if the param is not SelectedRows, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "adam", "scale", "scale"])
+
+        trainer = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send',
+            'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
+class TestDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv',
+            'fetch_barrier'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()