@@ -36,7 +36,7 @@ def get_accumulators(self):
36
36
def get_velocity_str (self ):
37
37
return self ._velocity_acc_str
38
38
39
- def test_momentum_optimizer (self ):
39
+ def test_vanilla_momentum_optimizer (self ):
40
40
program = framework .Program ()
41
41
block = program .global_block ()
42
42
mul_x = block .create_parameter (
@@ -60,6 +60,42 @@ def test_momentum_optimizer(self):
60
60
self .assertEqual (len (opts ), 1 )
61
61
sgd_op = opts [0 ]
62
62
self .assertEqual (sgd_op .type , "momentum" )
63
+ self .assertFalse (sgd_op .attr ('useNesterov' ))
64
+
65
+ # Check accumulators
66
+ accumulators = momentum_optimizer .get_accumulators ()
67
+ self .assertEqual (len (accumulators ), 1 )
68
+ self .assertTrue (momentum_optimizer .get_velocity_str () in accumulators )
69
+ velocity_acc = accumulators [momentum_optimizer .get_velocity_str ()]
70
+ self .assertEqual (len (velocity_acc ), 1 )
71
+ self .assertTrue (mul_x .name in velocity_acc )
72
+
73
+ def test_nesterov_momentum_optimizer (self ):
74
+ program = framework .Program ()
75
+ block = program .global_block ()
76
+ mul_x = block .create_parameter (
77
+ dtype = "float32" , shape = [5 , 10 ], lod_level = 0 , name = "mul.x" )
78
+ mul_y = block .create_var (
79
+ dtype = "float32" , shape = [10 , 8 ], lod_level = 0 , name = "mul.y" )
80
+ mul_out = block .create_var (
81
+ dtype = "float32" , shape = [5 , 8 ], lod_level = 0 , name = "mul.out" )
82
+ block .append_op (
83
+ type = "mul" ,
84
+ inputs = {"X" : mul_x ,
85
+ "Y" : mul_y },
86
+ outputs = {"Out" : mul_out },
87
+ attrs = {"x_num_col_dims" : 1 })
88
+ momentum_optimizer = self .MockMomentum (
89
+ learning_rate = 0.01 , momentum = 0.2 , use_nesterov = True )
90
+ params_grads = append_backward_ops (mul_out )
91
+ self .assertEqual (len (params_grads ), 1 )
92
+ self .assertEqual (len (momentum_optimizer .get_accumulators ()), 0 )
93
+ opts = momentum_optimizer .create_optimization_pass (params_grads ,
94
+ mul_out )
95
+ self .assertEqual (len (opts ), 1 )
96
+ sgd_op = opts [0 ]
97
+ self .assertEqual (sgd_op .type , "momentum" )
98
+ self .assertTrue (sgd_op .attr ('useNesterov' ))
63
99
64
100
# Check accumulators
65
101
accumulators = momentum_optimizer .get_accumulators ()
0 commit comments