Skip to content

Commit dd0008d

Browse files
authored
Extract apply_backward_pass to backward.py (#5026)
* Extract apply_backward_pass to backward.py; rename apply_backward_pass to append_backward_ops
* Fix CI
* Update design doc
1 parent fd2eb55 commit dd0008d

File tree

4 files changed

+56
-61
lines changed

4 files changed

+56
-61
lines changed

doc/design/optimizer.md

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,6 @@ class Optimizer(object):
6565
def __init__(self):
6666
pass
6767

68-
def create_backward_pass(self, loss, parameter_list=None):
69-
"""
70-
create and add gradient Operators in BlockDesc to Compute gradients of `loss`
71-
for parameters in parameter_list
72-
73-
Args:
74-
loss: an variable generated by cost function.
75-
parameter_list: parameters that need to compute gradient and update to optimize the lost.
76-
77-
Returns:
78-
list of (parameters, gradients) pair.
79-
"""
80-
return None
81-
8268
def create_optimization_pass(self, parameters_and_grads):
8369
"""Add optimization operators to update gradients to variables.
8470
@@ -93,7 +79,7 @@ class Optimizer(object):
9379
def minimize(self, loss, parameter_list):
9480
"""Add operations to minimize `loss` by updating `parameter_list`.
9581
96-
This method combines interface `create_backward_pass()` and
82+
This method combines interface `append_backward_ops()` and
9783
`create_optimization_pass()` into one.
9884
"""
9985
params_grads = append_backward_ops(loss, parameter_list)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from paddle.v2.framework import framework as framework
2+
3+
__all__ = ['append_backward_ops']
4+
5+
6+
def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
7+
"""
8+
Create and add gradient Operators in BlockDesc to compute
9+
gradients of `loss` for parameters in parameter_list
10+
11+
:param loss: an variable generated by cost function.
12+
:type loss: Variable
13+
:param no_grad_set: variable that should not create gradient
14+
:type no_grad_set: set
15+
:param parameter_list: parameters that need to compute gradient and
16+
update to optimize the lost.
17+
:type: list
18+
:return: list of (parameters, gradients) pair.
19+
:rtype: list[Variable]
20+
"""
21+
assert isinstance(loss, framework.Variable)
22+
param_grad_map = loss.block.program.append_backward(loss, no_grad_set or
23+
set())
24+
if parameter_list is not None:
25+
parameters = parameter_list
26+
else:
27+
params = loss.block.program.global_block().all_parameters()
28+
parameters = [param.name for param in params]
29+
params_and_grads = []
30+
for param in parameters:
31+
if param not in param_grad_map:
32+
raise ValueError("param %s is not in map" % param)
33+
grad_info = param_grad_map[param]
34+
grad_block = loss.block.program.block(grad_info[1])
35+
if not grad_block.has_var(grad_info[0]):
36+
raise ValueError("grad block[{0}] did not have grad var {1}".format(
37+
grad_info[1], grad_info[0]))
38+
# Get the param var from the global block
39+
param_var = loss.block.program.global_block().var(param)
40+
grad_var = grad_block.var(grad_info[0])
41+
if loss.block.has_var(grad_info[0]):
42+
params_and_grads.append((param_var, grad_var))
43+
else:
44+
params_and_grads.append((param_var, None))
45+
return params_and_grads

python/paddle/v2/framework/optimizer.py

Lines changed: 6 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import paddle.v2.framework.framework as framework
21
from collections import defaultdict
32

3+
import paddle.v2.framework.framework as framework
4+
from paddle.v2.framework.backward import append_backward_ops
5+
46
__all__ = [
57
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer'
68
]
@@ -105,45 +107,6 @@ def _get_accumulator(self, name, param):
105107
format(name, param.name))
106108
return self._accumulators[name][param.name]
107109

108-
def create_backward_pass(self, loss, parameter_list=None, no_grad_set=None):
109-
"""Create and add gradient Operators in BlockDesc to compute
110-
gradients of `loss` for parameters in parameter_list
111-
112-
Args:
113-
loss: an variable generated by cost function.
114-
no_grad_set: variable that should not create gradient
115-
parameter_list: parameters that need to compute gradient and
116-
update to optimize the lost.
117-
118-
Returns:
119-
list of (parameters, gradients) pair.
120-
"""
121-
assert isinstance(loss, framework.Variable)
122-
param_grad_map = loss.block.program.append_backward(loss, no_grad_set or
123-
set())
124-
if parameter_list is not None:
125-
parameters = parameter_list
126-
else:
127-
params = loss.block.program.global_block().all_parameters()
128-
parameters = [param.name for param in params]
129-
params_and_grads = []
130-
for param in parameters:
131-
if param not in param_grad_map:
132-
raise Exception("param %s is not in map" % param)
133-
grad_info = param_grad_map[param]
134-
grad_block = loss.block.program.block(grad_info[1])
135-
if not grad_block.has_var(grad_info[0]):
136-
raise Exception("grad block[%d] did not have grad var %s" %
137-
grad_info[1], grad_info[0])
138-
# Get the param var from the global block
139-
param_var = loss.block.program.global_block().var(param)
140-
grad_var = grad_block.var(grad_info[0])
141-
if loss.block.has_var(grad_info[0]):
142-
params_and_grads.append((param_var, grad_var))
143-
else:
144-
params_and_grads.append((param_var, None))
145-
return params_and_grads
146-
147110
def create_optimization_pass(self, parameters_and_grads, loss):
148111
"""Add optimization operators to update gradients to variables.
149112
@@ -192,11 +155,11 @@ def create_optimization_pass(self, parameters_and_grads, loss):
192155
def minimize(self, loss, parameter_list=None, no_grad_set=None):
193156
"""Add operations to minimize `loss` by updating `parameter_list`.
194157
195-
This method combines interface `create_backward_pass()` and
158+
This method combines interface `append_backward_ops()` and
196159
`create_optimization_pass()` into one.
197160
"""
198-
params_grads = self.create_backward_pass(loss, parameter_list,
199-
no_grad_set or set())
161+
params_grads = append_backward_ops(loss, parameter_list, no_grad_set or
162+
set())
200163
optimize_ops = self.create_optimization_pass(params_grads, loss)
201164
return optimize_ops
202165

python/paddle/v2/framework/tests/test_optimizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import paddle.v2.framework.framework as framework
44
import paddle.v2.framework.optimizer as optimizer
5+
from paddle.v2.framework.backward import append_backward_ops
56

67

78
class TestOptimizer(unittest.TestCase):
@@ -51,7 +52,7 @@ def test_momentum_optimizer(self):
5152
outputs={"Out": mul_out},
5253
attrs={"x_num_col_dims": 1})
5354
momentum_optimizer = self.MockMomentum(learning_rate=0.01, momentum=0.2)
54-
params_grads = momentum_optimizer.create_backward_pass(mul_out)
55+
params_grads = append_backward_ops(mul_out)
5556
self.assertEqual(len(params_grads), 1)
5657
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
5758
opts = momentum_optimizer.create_optimization_pass(params_grads,
@@ -93,7 +94,7 @@ def test_adagrad_optimizer(self):
9394
outputs={"Out": mul_out},
9495
attrs={"x_num_col_dims": 1})
9596
adagrad_optimizer = self.MockAdagrad(learning_rate=0.01, epsilon=1.0e-6)
96-
params_grads = adagrad_optimizer.create_backward_pass(mul_out)
97+
params_grads = append_backward_ops(mul_out)
9798
self.assertEqual(len(params_grads), 1)
9899
self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
99100
opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out)
@@ -138,7 +139,7 @@ def test_adam_optimizer(self):
138139
attrs={"x_num_col_dims": 1})
139140
adam_optimizer = self.MockAdam(
140141
learning_rate=0.01, beta1=0.9, beta2=0.999)
141-
params_grads = adam_optimizer.create_backward_pass(mul_out)
142+
params_grads = append_backward_ops(mul_out)
142143
self.assertEqual(len(params_grads), 1)
143144
self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
144145
opts = adam_optimizer.create_optimization_pass(params_grads, mul_out)

0 commit comments

Comments
 (0)