@@ -434,5 +434,71 @@ def test_decayed_adagrad_optimizer(self):
434
434
self .assertAlmostEqual (init_ops [1 ].attr ('value' ), 0.0 )
435
435
436
436
437
+ class TestFtrlOptimizer (unittest .TestCase ):
438
+ class MockFtrl (optimizer .FtrlOptimizer ):
439
+ def get_accumulators (self ):
440
+ return self ._accumulators
441
+
442
+ def get_squared_str (self ):
443
+ return self ._squared_acc_str
444
+
445
+ def get_linear_str (self ):
446
+ return self ._linear_acc_str
447
+
448
+ def test_ftrl_optimizer (self ):
449
+ init_program = framework .Program ()
450
+ program = framework .Program ()
451
+ block = program .global_block ()
452
+ mul_x = block .create_parameter (
453
+ dtype = "float32" ,
454
+ shape = [5 , 10 ],
455
+ lod_level = 0 ,
456
+ name = "mul.x" ,
457
+ optimize_attr = {'learning_rate' : 1.1 })
458
+ mul_y = block .create_var (
459
+ dtype = "float32" , shape = [10 , 8 ], lod_level = 0 , name = "mul.y" )
460
+ mul_out = block .create_var (
461
+ dtype = "float32" , shape = [5 , 8 ], lod_level = 0 , name = "mul.out" )
462
+ block .append_op (
463
+ type = "mul" ,
464
+ inputs = {"X" : mul_x ,
465
+ "Y" : mul_y },
466
+ outputs = {"Out" : mul_out },
467
+ attrs = {"x_num_col_dims" : 1 })
468
+ mean_out = block .create_var (
469
+ dtype = "float32" , shape = [1 ], lod_level = 0 , name = "mean.out" )
470
+ block .append_op (
471
+ type = "mean" , inputs = {"X" : mul_out }, outputs = {"Out" : mean_out })
472
+ learning_rate = 0.01
473
+ ftrl_optimizer = self .MockFtrl (
474
+ learning_rate = learning_rate , l1 = 0.0 , l2 = 0.0 , lr_power = - 0.5 )
475
+ params_grads = append_backward (mean_out )
476
+ self .assertEqual (len (params_grads ), 1 )
477
+ self .assertEqual (len (ftrl_optimizer .get_accumulators ()), 0 )
478
+ opts = ftrl_optimizer .create_optimization_pass (params_grads , mul_out ,
479
+ init_program )
480
+ self .assertEqual (len (opts ), 3 )
481
+ self .assertEqual ([op .type for op in opts ],
482
+ ["fill_constant" , "elementwise_mul" , "ftrl" ])
483
+
484
+ # Check accumulators
485
+ accumulators = ftrl_optimizer .get_accumulators ()
486
+ self .assertEqual (len (accumulators ), 2 )
487
+ self .assertTrue (ftrl_optimizer .get_squared_str () in accumulators )
488
+ self .assertTrue (ftrl_optimizer .get_linear_str () in accumulators )
489
+ squared_acc = accumulators [ftrl_optimizer .get_squared_str ()]
490
+ linear_acc = accumulators [ftrl_optimizer .get_linear_str ()]
491
+ self .assertEqual (len (squared_acc ), 1 )
492
+ self .assertEqual (len (linear_acc ), 1 )
493
+ self .assertTrue (mul_x .name in squared_acc )
494
+ self .assertTrue (mul_x .name in linear_acc )
495
+
496
+ # Check init_program
497
+ init_ops = init_program .global_block ().ops
498
+ self .assertEqual (len (init_ops ), 3 )
499
+ self .assertEqual (init_ops [0 ].type , "fill_constant" )
500
+ self .assertAlmostEqual (init_ops [0 ].attr ('value' ), learning_rate )
501
+
502
+
437
503
if __name__ == '__main__' :
438
504
unittest .main ()
0 commit comments