@@ -198,7 +198,7 @@ def test_adagrad_optimizer(self):
         adagrad_op = opts[0]
         self.assertEqual(adagrad_op.type, "adagrad")

-        # check accumulators
+        # Check accumulators
         accumulators = adagrad_optimizer.get_accumulators()
         self.assertEqual(len(accumulators), 1)
         self.assertTrue(adagrad_optimizer.get_moment_str() in accumulators)
@@ -331,5 +331,59 @@ def test_adamax_optimizer(self):
         self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)


+class TestDecayedAdagradOptimizer(unittest.TestCase):
+    class MockDecayedAdagrad(optimizer.DecayedAdagradOptimizer):
+        def get_accumulators(self):
+            return self._accumulators
+
+        def get_moment_str(self):
+            return self._moment_acc_str
+
+    def test_decayed_adagrad_optimizer(self):
+        init_program = framework.Program()
+        program = framework.Program()
+        block = program.global_block()
+        mul_x = block.create_parameter(
+            dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
+        mul_y = block.create_var(
+            dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
+        mul_out = block.create_var(
+            dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
+        block.append_op(
+            type="mul",
+            inputs={"X": mul_x,
+                    "Y": mul_y},
+            outputs={"Out": mul_out},
+            attrs={"x_num_col_dims": 1})
+        learning_rate = 0.01
+        decayed_adagrad_optimizer = self.MockDecayedAdagrad(
+            learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6)
+        params_grads = append_backward_ops(mul_out)
+        self.assertEqual(len(params_grads), 1)
+        self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
+        opts = decayed_adagrad_optimizer.create_optimization_pass(
+            params_grads, mul_out, init_program)
+        self.assertEqual(len(opts), 1)
+        decayed_adagrad_op = opts[0]
+        self.assertEqual(decayed_adagrad_op.type, "decayed_adagrad")
+
+        # Check accumulators
+        accumulators = decayed_adagrad_optimizer.get_accumulators()
+        self.assertEqual(len(accumulators), 1)
+        self.assertTrue(
+            decayed_adagrad_optimizer.get_moment_str() in accumulators)
+        moment_acc = accumulators[decayed_adagrad_optimizer.get_moment_str()]
+        self.assertEqual(len(moment_acc), 1)
+        self.assertTrue(mul_x.name in moment_acc)
+
+        # Check init_program
+        init_ops = init_program.global_block().ops
+        self.assertEqual(len(init_ops), 2)
+        self.assertEqual(init_ops[0].type, "fill_constant")
+        self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
+        self.assertEqual(init_ops[1].type, "fill_constant")
+        self.assertAlmostEqual(init_ops[1].attr('value'), 0.0)
+
+
 if __name__ == '__main__':
     unittest.main()
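
For reference, the decayed_adagrad op this test checks maintains a decayed moving average of squared gradients, unlike plain Adagrad's unbounded sum. Below is a minimal NumPy sketch of one update step, assuming the standard formulation; the function name and signature are illustrative, not taken from this patch:

import numpy as np

def decayed_adagrad_step(param, grad, moment, lr=0.01, decay=0.95,
                         epsilon=1.0e-6):
    # moment is the per-parameter accumulator the test inspects via
    # get_accumulators(); it starts at 0.0, which matches the second
    # fill_constant op the test expects in init_program.
    moment = decay * moment + (1.0 - decay) * grad * grad
    param = param - lr * grad / (np.sqrt(moment) + epsilon)
    return param, moment

The two fill_constant ops asserted in init_program correspond to the learning-rate constant (0.01) and this moment accumulator's zero initialization.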