
Commit d78d119

Adding python wrapper for adam operator (#5021)
* Adding Adam Python wrapper
* Adding tests for Python Adam wrapper
1 parent: 046b815 · commit: d78d119

2 files changed: +202 −5 lines changed


python/paddle/v2/framework/optimizer.py

Lines changed: 153 additions & 5 deletions
@@ -1,7 +1,9 @@
 import paddle.v2.framework.framework as framework
 from collections import defaultdict
 
-__all__ = ['SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer']
+__all__ = [
+    'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer'
+]
 
 
 class Optimizer(object):
@@ -43,6 +45,19 @@ def _create_accumulators(self, block, parameters):
         """
         pass
 
+    def _finish_update(self, block):
+        """Finish any custom updates needed
+           before completing an optimization step
+
+        Args:
+            block: the block in which the loss variable is present
+            parameters: list of parameter variables for the optimizer
+
+        Returns:
+            list of finish ops or None
+        """
+        pass
+
     def _add_accumulator(self, block, name, param, dtype=None, fill_value=0.0):
         """Utility function to add an accumulator for a parameter
 
@@ -137,15 +152,17 @@ def create_optimization_pass(self, parameters_and_grads, loss):
           parameters_and_grads: a list of (variable, gradient) pair to update.
 
         Returns:
-          optmization_op_list: a list of optimization operator that will update
-          parameter using gradient.
+          return_op_list: a list of operators that will complete one step of
+          optimization. This will include parameter update ops, global step
+          update ops and any other custom ops required by subclasses to manage
+          their internal state.
         """
         # This is a default implementation of create_optimization_pass that
         # can be shared by most optimizers. This implementation assumes that
         # the subclass will implement the _append_optimize_op method and the
         # _initialize_tensors method. The subclass can extend the
         # _create_accumulators method if it needs to create accumulators
-        # for parameters.
+        # for parameters and extend _finish_update method to add custom ops.
 
         # Create any accumulators
         self._create_accumulators(loss.block,
@@ -160,7 +177,17 @@ def create_optimization_pass(self, parameters_and_grads, loss):
                                                        param_and_grad)
                 optimize_ops.append(optimize_op)
 
-        return optimize_ops
+        # Returned list of ops can include more ops in addition
+        # to optimization ops
+        return_ops = optimize_ops
+
+        # Get custom finish ops for subclasses
+        # FIXME: Need to fix this once we figure out how to handle dependencies
+        finish_ops = self._finish_update(loss.block)
+        if finish_ops is not None:
+            return_ops += finish_ops
+
+        return return_ops
 
     def minimize(self, loss, parameter_list=None, no_grad_set=None):
         """Add operations to minimize `loss` by updating `parameter_list`.
@@ -329,3 +356,124 @@ def _append_optimize_op(self, block, param_and_grad):
             attrs={"epsilon": self._epsilon})
 
         return adagrad_op
+
+
+class AdamOptimizer(Optimizer):
+    """Implements the Adam Optimizer
+    """
+    _moment1_acc_str = "moment1"
+    _moment2_acc_str = "moment2"
+
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta1=0.9,
+                 beta2=0.999,
+                 epsilon=1e-8):
+        assert learning_rate is not None
+        assert beta1 is not None
+        assert beta2 is not None
+        assert epsilon is not None
+        super(AdamOptimizer, self).__init__()
+        self.type = "adam"
+        self._learning_rate = learning_rate
+        self._beta1 = beta1
+        self._beta2 = beta2
+        self._epsilon = epsilon
+
+    def _initialize_tensors(self, block):
+        assert isinstance(block, framework.Block)
+        lr_shape = [1]
+        # create a variable for learning_rate
+        self._lr = block.create_var(
+            dtype="float32", shape=lr_shape, lod_level=0)
+
+        # create an op to init the learning_rate
+        # FIXME: Fix when Initialization design has been implemented
+        # https://github.com/PaddlePaddle/Paddle/pull/4852
+        block.append_op(
+            type="fill_constant",
+            outputs={"Out": self._lr},
+            attrs={"shape": lr_shape,
+                   "value": self._learning_rate})
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+
+        global_block = block.program.global_block()
+        # Create beta1 and beta2 power tensors
+        beta_shape = [1]
+        # Create variables for beta1 and beta2 powers
+        self._beta1_pow_acc = global_block.create_var(
+            dtype="float32", shape=beta_shape, lod_level=0)
+        self._beta2_pow_acc = global_block.create_var(
+            dtype="float32", shape=beta_shape, lod_level=0)
+
+        # Initialize beta1 and beta2 power accumulators
+        # FIXME: Fix when Initialization design has been implemented
+        # https://github.com/PaddlePaddle/Paddle/pull/4852
+        global_block.append_op(
+            type="fill_constant",
+            outputs={"Out": self._beta1_pow_acc},
+            attrs={"shape": beta_shape,
+                   "value": self._beta1})
+        global_block.append_op(
+            type="fill_constant",
+            outputs={"Out": self._beta2_pow_acc},
+            attrs={"shape": beta_shape,
+                   "value": self._beta2})
+
+        # Create accumulator tensors for first and second moments
+        for p in parameters:
+            self._add_accumulator(block, self._moment1_acc_str, p, 'float32')
+            self._add_accumulator(block, self._moment2_acc_str, p, 'float32')
+
+    def _append_optimize_op(self, block, param_and_grad):
+        assert isinstance(block, framework.Block)
+
+        moment1 = self._get_accumulator(self._moment1_acc_str,
+                                        param_and_grad[0])
+        moment2 = self._get_accumulator(self._moment2_acc_str,
+                                        param_and_grad[0])
+        # create the adam optimize op
+        adam_op = block.append_op(
+            type=self.type,
+            inputs={
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "LearningRate": self._lr,
+                "Moment1": moment1,
+                "Moment2": moment2,
+                "Beta1Pow": self._beta1_pow_acc,
+                "Beta2Pow": self._beta2_pow_acc
+            },
+            outputs={
+                "ParamOut": param_and_grad[0],
+                "Moment1Out": moment1,
+                "Moment2Out": moment2
+            },
+            attrs={
+                "beta1": self._beta1,
+                "beta2": self._beta2,
+                "epsilon": self._epsilon
+            })
+
+        return adam_op
+
+    def _finish_update(self, block):
+        """Update Beta1 and Beta2 Power accumulators
+        """
+        assert isinstance(block, framework.Block)
+        global_block = block.program.global_block()
+        scale_beta1 = global_block.append_op(
+            type="scale",
+            inputs={"X": self._beta1_pow_acc},
+            outputs={"Out": self._beta1_pow_acc},
+            attrs={"scale": self._beta1})
+
+        scale_beta2 = global_block.append_op(
+            type="scale",
+            inputs={"X": self._beta2_pow_acc},
+            outputs={"Out": self._beta2_pow_acc},
+            attrs={"scale": self._beta2})
+
+        return [scale_beta1, scale_beta2]
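
For reference, a minimal NumPy sketch (not part of the commit) of the update that the adam op plus the two scale ops appended by _finish_update are together expected to realize. It follows the standard Adam rule of Kingma & Ba; the function adam_step and its signature are illustrative only and do not correspond to the actual operator kernel.

import numpy as np

def adam_step(param, grad, moment1, moment2, beta1_pow, beta2_pow,
              lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    # Moment1Out / Moment2Out: exponential moving averages of the gradient
    # and the squared gradient, written back into the moment1 / moment2
    # accumulators created in _create_accumulators.
    moment1 = beta1 * moment1 + (1 - beta1) * grad
    moment2 = beta2 * moment2 + (1 - beta2) * np.square(grad)
    # Bias correction uses the running powers of beta1 and beta2
    # (the Beta1Pow / Beta2Pow inputs).
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    # ParamOut
    param = param - lr_t * moment1 / (np.sqrt(moment2) + epsilon)
    # The two scale ops returned by _finish_update advance the power
    # accumulators once per optimization step.
    beta1_pow = beta1_pow * beta1
    beta2_pow = beta2_pow * beta2
    return param, moment1, moment2, beta1_pow, beta2_pow

Note that _create_accumulators fills Beta1Pow and Beta2Pow with beta1 and beta2, so at step t they hold beta1**t and beta2**t, which is exactly what the bias-corrected learning rate needs.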

python/paddle/v2/framework/tests/test_optimizer.py

Lines changed: 49 additions & 0 deletions
@@ -110,5 +110,54 @@ def test_adagrad_optimizer(self):
         self.assertTrue(mul_x.name in moment_acc)
 
 
+class TestAdamOptimizer(unittest.TestCase):
+    class MockAdam(optimizer.AdamOptimizer):
+        def get_accumulators(self):
+            return self._accumulators
+
+        def get_moment1_str(self):
+            return self._moment1_acc_str
+
+        def get_moment2_str(self):
+            return self._moment2_acc_str
+
+    def test_adam_optimizer(self):
+        program = framework.Program()
+        block = program.global_block()
+        mul_x = block.create_parameter(
+            dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
+        mul_y = block.create_var(
+            dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
+        mul_out = block.create_var(
+            dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
+        block.append_op(
+            type="mul",
+            inputs={"X": mul_x,
+                    "Y": mul_y},
+            outputs={"Out": mul_out},
+            attrs={"x_num_col_dims": 1})
+        adam_optimizer = self.MockAdam(
+            learning_rate=0.01, beta1=0.9, beta2=0.999)
+        params_grads = adam_optimizer.create_backward_pass(mul_out)
+        self.assertEqual(len(params_grads), 1)
+        self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
+        opts = adam_optimizer.create_optimization_pass(params_grads, mul_out)
+        self.assertEqual(len(opts), 3)
+        adam_op = opts[0]
+        self.assertEqual(adam_op.type, "adam")
+
+        # Check accumulators
+        accumulators = adam_optimizer.get_accumulators()
+        self.assertEqual(len(accumulators), 2)
+        self.assertTrue(adam_optimizer.get_moment1_str() in accumulators)
+        self.assertTrue(adam_optimizer.get_moment2_str() in accumulators)
+        moment1_acc = accumulators[adam_optimizer.get_moment1_str()]
+        moment2_acc = accumulators[adam_optimizer.get_moment2_str()]
+        self.assertEqual(len(moment1_acc), 1)
+        self.assertEqual(len(moment2_acc), 1)
+        self.assertTrue(mul_x.name in moment1_acc)
+        self.assertTrue(mul_x.name in moment2_acc)
+
+
 if __name__ == '__main__':
     unittest.main()
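
The test drives the optimizer through create_backward_pass and create_optimization_pass directly so it can inspect the individual ops. A hypothetical end-to-end sketch using the minimize() entry point of the base Optimizer class (shown in optimizer.py above) might look like the following; this is an assumption about intended usage, not code from the commit.

import paddle.v2.framework.framework as framework
import paddle.v2.framework.optimizer as optimizer

program = framework.Program()
block = program.global_block()
x = block.create_parameter(
    dtype="float32", shape=[5, 10], lod_level=0, name="x")
y = block.create_var(dtype="float32", shape=[10, 8], lod_level=0, name="y")
out = block.create_var(dtype="float32", shape=[5, 8], lod_level=0, name="out")
block.append_op(
    type="mul",
    inputs={"X": x,
            "Y": y},
    outputs={"Out": out},
    attrs={"x_num_col_dims": 1})

adam = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999)
# minimize() is expected to append the backward pass plus the adam op and the
# two beta-power scale ops to the program; executing the program afterwards
# is a separate step not shown here.
ops = adam.minimize(out)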
