Skip to content

Commit a384249

Browse files
authored
Adding nesterov momentum to python momentum wrapper (#5055)
* Adding nesterov momentum to python momentum wrapper
* Fixing optimizer test after merge
1 parent 0760043 commit a384249

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

python/paddle/v2/framework/optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,14 @@ class MomentumOptimizer(Optimizer):
211211
"""
212212
_velocity_acc_str = "velocity"
213213

214-
def __init__(self, learning_rate, momentum):
214+
def __init__(self, learning_rate, momentum, use_nesterov=False):
215215
assert learning_rate is not None
216216
assert momentum is not None
217217
super(MomentumOptimizer, self).__init__()
218218
self.type = "momentum"
219219
self._learning_rate = learning_rate
220220
self._momentum = momentum
221+
self._use_nesterov = bool(use_nesterov)
221222

222223
def _initialize_tensors(self, block):
223224
assert isinstance(block, framework.Block)
@@ -259,7 +260,8 @@ def _append_optimize_op(self, block, param_and_grad):
259260
"ParamOut": param_and_grad[0],
260261
"VelocityOut": velocity_acc
261262
},
262-
attrs={"mu": self._momentum})
263+
attrs={"mu": self._momentum,
264+
"useNesterov": self._use_nesterov})
263265

264266
return momentum_op
265267

python/paddle/v2/framework/tests/test_optimizer.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_accumulators(self):
3636
def get_velocity_str(self):
3737
return self._velocity_acc_str
3838

39-
def test_momentum_optimizer(self):
39+
def test_vanilla_momentum_optimizer(self):
4040
program = framework.Program()
4141
block = program.global_block()
4242
mul_x = block.create_parameter(
@@ -60,6 +60,42 @@ def test_momentum_optimizer(self):
6060
self.assertEqual(len(opts), 1)
6161
sgd_op = opts[0]
6262
self.assertEqual(sgd_op.type, "momentum")
63+
self.assertFalse(sgd_op.attr('useNesterov'))
64+
65+
# Check accumulators
66+
accumulators = momentum_optimizer.get_accumulators()
67+
self.assertEqual(len(accumulators), 1)
68+
self.assertTrue(momentum_optimizer.get_velocity_str() in accumulators)
69+
velocity_acc = accumulators[momentum_optimizer.get_velocity_str()]
70+
self.assertEqual(len(velocity_acc), 1)
71+
self.assertTrue(mul_x.name in velocity_acc)
72+
73+
def test_nesterov_momentum_optimizer(self):
74+
program = framework.Program()
75+
block = program.global_block()
76+
mul_x = block.create_parameter(
77+
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
78+
mul_y = block.create_var(
79+
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
80+
mul_out = block.create_var(
81+
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
82+
block.append_op(
83+
type="mul",
84+
inputs={"X": mul_x,
85+
"Y": mul_y},
86+
outputs={"Out": mul_out},
87+
attrs={"x_num_col_dims": 1})
88+
momentum_optimizer = self.MockMomentum(
89+
learning_rate=0.01, momentum=0.2, use_nesterov=True)
90+
params_grads = append_backward_ops(mul_out)
91+
self.assertEqual(len(params_grads), 1)
92+
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
93+
opts = momentum_optimizer.create_optimization_pass(params_grads,
94+
mul_out)
95+
self.assertEqual(len(opts), 1)
96+
sgd_op = opts[0]
97+
self.assertEqual(sgd_op.type, "momentum")
98+
self.assertTrue(sgd_op.attr('useNesterov'))
6399

64100
# Check accumulators
65101
accumulators = momentum_optimizer.get_accumulators()

0 commit comments

Comments (0)