
Commit dfa05da

[cherry-pick] fuse L2Decay and momentum when param.regularizer is set (#32845) (#32881)
Fuse L2Decay and momentum when param.regularizer is set; cherry-pick of #32845.
1 parent 1cdf69b commit dfa05da
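
In user terms, the change targets the case where a parameter carries its own `L2Decay` regularizer via `ParamAttr`: instead of appending separate `scale`/`sum` ops for the decay, the coefficient is folded into the `momentum` op through its `regularization_method` and `regularization_coeff` attributes. A minimal sketch of that scenario, adapted from the new `TestFusedMomentumWithDecayAPI` test further down (the layer shape and coefficient are illustrative, not part of the commit):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # Per-parameter L2Decay set through ParamAttr: the case this commit fuses.
    weight_attr = paddle.ParamAttr(
        name="weight", regularizer=paddle.regularizer.L2Decay(0.1))
    x = paddle.static.data(name='x', shape=[10, 10])
    out = paddle.nn.Linear(10, 10, weight_attr=weight_attr, bias_attr=False)(x)
    loss = paddle.mean(out)
    paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9).minimize(loss)

# With this change the decay lives on the momentum op itself
# (expected values per the new unit tests):
ops = main.global_block().ops
print(ops[-1].type)                           # 'momentum'
print(ops[-1].attr('regularization_method'))  # 'l2_decay'
print(ops[-1].attr('regularization_coeff'))   # 0.1
# ...and no standalone 'scale'/'sum' regularization ops are appended.
```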

File tree

6 files changed: 288 additions & 102 deletions


python/paddle/fluid/optimizer.py

Lines changed: 93 additions & 7 deletions
@@ -33,7 +33,6 @@
 from .initializer import Constant
 from .layer_helper import LayerHelper
 from .layers import ops
-from .regularizer import append_regularization_ops
 from .dygraph import base as imperative_base
 from .dygraph import no_grad
 from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
@@ -805,6 +804,93 @@ def backward(self,
                                            act_no_grad_set, callbacks)
         return params_grads

+    def _create_regularization_of_grad(self, param, grad, regularization=None):
+        """ Create and add backward regularization Operators
+
+        Function helper of append_regularization_ops.
+        """
+        # If no gradient or no regularization is specified, then we don't need to do anything
+        if grad is None or ((not hasattr(param, 'regularizer') or
+                             (hasattr(param, 'regularizer') and
+                              param.regularizer is None)) and
+                            regularization is None):
+            return grad
+        regularization_term = None
+        if hasattr(param, 'regularizer') and param.regularizer is not None:
+            # Add variable for regularization term in grad block
+            regularization_term = param.regularizer(param, grad, grad.block)
+        elif regularization is not None:
+            regularization_term = regularization(param, grad, grad.block)
+
+        assert regularization_term is not None
+
+        new_grad = grad
+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
+            # the grad's type and name will be changed. But the gradient's name
+            # is used in ParallelExecutor Reduce mode, so I add a flag for
+            # the new_grad here.
+            new_grad = grad.block.create_var(
+                name=grad.name + core.kNewGradSuffix(),
+                dtype=param.dtype,
+                shape=param.shape,
+                lod_level=param.lod_level,
+                type=core.VarDesc.VarType.LOD_TENSOR)
+
+        inputs = {"X": [grad, regularization_term]}
+        outputs = {"Out": [new_grad]}
+        if framework.in_dygraph_mode():
+            new_grad = core.ops.sum([grad, regularization_term])
+        else:
+            grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
+
+        return new_grad
+
+    def append_regularization_ops(self,
+                                  parameters_and_grads,
+                                  regularization=None):
+        r"""Create and add backward regularization Operators
+
+        Creates and adds backward regularization operators in the BlockDesc.
+        This will add gradients of the regularizer function to the gradients
+        of the parameters and return these modified gradients. This is the
+        same as implementing weight decay in optimizers for regularization.
+
+        Args:
+            parameters_and_grads: A list of (parameters, gradients) pairs
+                that need to be regularized.
+            regularization: A global regularizer. If the parameter is not
+                set. It will be applied with regularizer.
+
+        Returns:
+            list[(Variable, Variable)]: list of (parameters, gradients) \
+            pair with the regularized gradient
+
+        Raises:
+            Exception: Unknown regularization type
+        """
+        params_and_grads = []
+        if framework.in_dygraph_mode():
+            for param, grad in parameters_and_grads:
+                new_grad = self._create_regularization_of_grad(param, grad,
+                                                               regularization)
+                params_and_grads.append((param, new_grad))
+        else:
+            repeate_regularizer = False
+            with framework.name_scope('regularization'):
+                for param, grad in parameters_and_grads:
+                    if not repeate_regularizer and param.regularizer is not None and regularization is not None:
+                        repeate_regularizer = True
+                        logging.info(
+                            "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
+                            "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
+                            % regularization.__str__())
+                    with param.block.program._optimized_guard([param, grad]):
+                        new_grad = self._create_regularization_of_grad(
+                            param, grad, regularization)
+                    params_and_grads.append((param, new_grad))
+        return params_and_grads
+
     def apply_gradients(self, params_grads):
         """
         Second part of `minimize`, appending optimization operators for
@@ -837,8 +923,8 @@ def apply_gradients(self, params_grads):
         params_grads = append_gradient_clip_ops(params_grads)

         # Add regularization if any
-        params_grads = append_regularization_ops(params_grads,
-                                                 self.regularization)
+        params_grads = self.append_regularization_ops(params_grads,
+                                                      self.regularization)

         optimize_ops = self._create_optimization_pass(params_grads)
         return optimize_ops
@@ -860,8 +946,8 @@ def apply_optimize(self, loss, startup_program, params_grads):
                                framework.default_startup_program()):
                 if self._grad_clip is not None:
                     params_grads = self._grad_clip(params_grads)
-                params_grads = append_regularization_ops(params_grads,
-                                                         self.regularization)
+                params_grads = self.append_regularization_ops(
+                    params_grads, self.regularization)
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
             program = loss.block.program
@@ -1595,8 +1681,8 @@ def apply_gradients(self, params_grads):
         not_dgc_params_grads = append_gradient_clip_ops(
             not_dgc_params_grads)

-        not_dgc_params_grads = append_regularization_ops(not_dgc_params_grads,
-                                                         self.regularization)
+        not_dgc_params_grads = self.append_regularization_ops(
+            not_dgc_params_grads, self.regularization)

         params_grads = not_dgc_params_grads + dgc_params_grads
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
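
The net effect in this file is an API move: `_create_regularization_of_grad` and `append_regularization_ops` become methods of the `Optimizer` base class instead of free functions in `regularizer.py`, and all internal call sites go through `self`. A hedged sketch of the new calling convention for code that previously used the free function (program structure and names are illustrative; `paddle.static.append_backward` is assumed to supply the `(parameter, gradient)` pairs):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    w_attr = paddle.ParamAttr(
        name='w', regularizer=paddle.regularizer.L2Decay(0.1))
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    y = paddle.nn.Linear(10, 1, weight_attr=w_attr)(x)
    loss = paddle.mean(y)
    params_grads = paddle.static.append_backward(loss)

    # Old call site (the free function removed from regularizer.py below):
    #   from paddle.fluid.regularizer import append_regularization_ops
    #   params_grads = append_regularization_ops(params_grads)
    # New call site, a method on any optimizer instance,
    # as the updated test_regularizer.py does:
    optimizer = paddle.optimizer.Adam()
    params_grads = optimizer.append_regularization_ops(params_grads)
```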

python/paddle/fluid/regularizer.py

Lines changed: 0 additions & 86 deletions
@@ -22,92 +22,6 @@
 __all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']


-def _create_regularization_of_grad(param, grad, regularization=None):
-    """ Create and add backward regularization Operators
-
-    Function helper of append_regularization_ops.
-    """
-    # If no gradient or no regularization is specified, then we don't need to do anything
-    if grad is None or ((not hasattr(param, 'regularizer') or (
-            hasattr(param, 'regularizer') and param.regularizer is None)) and
-                        regularization is None):
-        return grad
-    regularization_term = None
-    if hasattr(param, 'regularizer') and param.regularizer is not None:
-        # Add variable for regularization term in grad block
-        regularization_term = param.regularizer(param, grad, grad.block)
-    elif regularization is not None:
-        regularization_term = regularization(param, grad, grad.block)
-
-    assert regularization_term is not None
-
-    new_grad = grad
-    if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
-        # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
-        # the grad's type and name will be changed. But the gradient's name
-        # is used in ParallelExecutor Reduce mode, so I add a flag for
-        # the new_grad here.
-        new_grad = grad.block.create_var(
-            name=grad.name + core.kNewGradSuffix(),
-            dtype=param.dtype,
-            shape=param.shape,
-            lod_level=param.lod_level,
-            type=core.VarDesc.VarType.LOD_TENSOR)
-
-    inputs = {"X": [grad, regularization_term]}
-    outputs = {"Out": [new_grad]}
-    if in_dygraph_mode():
-        new_grad = core.ops.sum([grad, regularization_term])
-    else:
-        grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
-
-    return new_grad
-
-
-def append_regularization_ops(parameters_and_grads, regularization=None):
-    r"""Create and add backward regularization Operators
-
-    Creates and adds backward regularization operators in the BlockDesc.
-    This will add gradients of the regularizer function to the gradients
-    of the parameters and return these modified gradients. This is the
-    same as implementing weight decay in optimizers for regularization.
-
-    Args:
-        parameters_and_grads: A list of (parameters, gradients) pairs
-            that need to be regularized.
-        regularization: A global regularizer. If the parameter is not
-            set. It will be applied with regularizer.
-
-    Returns:
-        list[(Variable, Variable)]: list of (parameters, gradients) \
-        pair with the regularized gradient
-
-    Raises:
-        Exception: Unknown regularization type
-    """
-    params_and_grads = []
-    if in_dygraph_mode():
-        for param, grad in parameters_and_grads:
-            new_grad = _create_regularization_of_grad(param, grad,
-                                                      regularization)
-            params_and_grads.append((param, new_grad))
-    else:
-        repeate_regularizer = False
-        with framework.name_scope('regularization'):
-            for param, grad in parameters_and_grads:
-                if not repeate_regularizer and param.regularizer is not None and regularization is not None:
-                    repeate_regularizer = True
-                    logging.info(
-                        "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
-                        "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
-                        % regularization.__str__())
-                with param.block.program._optimized_guard([param, grad]):
-                    new_grad = _create_regularization_of_grad(param, grad,
-                                                              regularization)
-                params_and_grads.append((param, new_grad))
-    return params_and_grads
-
-
 class WeightDecayRegularizer(object):
     """Base class for weight decay regularizers

python/paddle/fluid/tests/unittests/test_momentum_op.py

Lines changed: 71 additions & 0 deletions
@@ -555,6 +555,77 @@ def test_momentum_static(self):
             exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)


+class TestFusedMomentumWithDecayAPI(unittest.TestCase):
+    def get_program(self, weight_attr, bias_attr=False):
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(
+                main_program=main_program, startup_program=startup_program):
+            x = paddle.static.data(name='x', shape=[10, 10])
+            linear = paddle.nn.Linear(
+                10, 10, weight_attr=weight_attr, bias_attr=bias_attr)
+            out = linear(x)
+            loss = paddle.mean(out)
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.01,
+                momentum=0.9,
+                weight_decay=paddle.regularizer.L2Decay(0.5))
+            optimizer.minimize(loss)
+        return main_program
+
+    def test_param_has_l2decay(self):
+        paddle.enable_static()
+        weight_attr = paddle.ParamAttr(
+            name="weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5),
+            regularizer=paddle.regularizer.L2Decay(0.1))
+        program = self.get_program(weight_attr, bias_attr=False)
+        ops = program.global_block().ops
+
+        self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
+        self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.1))
+        for i in range(len(ops)):
+            self.assertTrue('sum' not in ops[i].type)
+            self.assertTrue('scale' not in ops[i].type)
+
+    def test_param_has_l1decay(self):
+        paddle.enable_static()
+        weight_attr = paddle.ParamAttr(
+            name="weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5),
+            regularizer=paddle.regularizer.L1Decay(0.1))
+        bias_attr = paddle.ParamAttr(
+            name="bias",
+            initializer=paddle.nn.initializer.Constant(value=0.),
+            regularizer=None)
+        program = self.get_program(weight_attr, bias_attr)
+        ops = program.global_block().ops
+
+        self.assertEqual(ops[-1].type, 'momentum')
+        self.assertEqual(ops[-2].type, 'momentum')
+        self.assertEqual(ops[-3].type, 'sum')
+        self.assertEqual(ops[-4].type, 'scale')
+        self.assertEqual(ops[-5].type, 'sign')
+        self.assertEqual(ops[-6].type, 'matmul_grad')
+        if 'weight' in ops[-1].input('Param'):
+            self.assertEqual(ops[-1].attr('regularization_method'), '')
+            self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
+        if 'bias' in ops[-2].input('Param'):
+            self.assertEqual(ops[-2].attr('regularization_method'), 'l2_decay')
+            self.assertEqual(ops[-2].attr('regularization_coeff'),
+                             np.float32(0.5))
+
+    def test_param_has_no_regularizer(self):
+        paddle.enable_static()
+        program = self.get_program(weight_attr=None)
+        ops = program.global_block().ops
+        self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
+        self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5))
+        for i in range(len(ops)):
+            self.assertTrue('sum' not in ops[i].type)
+            self.assertTrue('scale' not in ops[i].type)
+
+
 class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
     def __update_params(self, momentum, linear):
         for i in range(10):
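
The new tests exercise the static-graph program; the dygraph branch of the same fusion (the `framework.in_dygraph_mode()` path that hands `regularization_method` / `regularization_coeff` to `core.ops.momentum`) is hit by an ordinary imperative training step. A hedged sketch, not part of the commit (layer size and data are illustrative):

```python
import paddle

# Per-parameter L2Decay: in dygraph mode Momentum skips the separate
# regularization op and passes the decay to the fused momentum kernel.
linear = paddle.nn.Linear(
    10, 10,
    weight_attr=paddle.ParamAttr(regularizer=paddle.regularizer.L2Decay(0.1)))
opt = paddle.optimizer.Momentum(
    learning_rate=0.01, momentum=0.9, parameters=linear.parameters())

x = paddle.rand([4, 10])
loss = paddle.mean(linear(x))
loss.backward()
opt.step()        # decay coefficient 0.1 is applied inside the momentum op
opt.clear_grad()
```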

python/paddle/fluid/tests/unittests/test_regularizer.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,7 @@ def test_l2decay_regularizer(self):
             params_grads = append_backward(mean_out)
             self.assertEqual(len(params_grads), 1)
             count_ops = len(block.ops)
+            optimizer = paddle.optimizer.Adam()
             params_grads = optimizer.append_regularization_ops(params_grads)
             self.assertEqual(len(params_grads), 1)
             self.assertEqual(len(block.ops), count_ops + 2)
@@ -97,6 +98,7 @@ def test_l2decay_regularizer(self):
             params_grads = append_backward(mean_out)
             self.assertEqual(len(params_grads), 1)
             count_ops = len(block.ops)
+            optimizer = paddle.optimizer.Adam()
             params_grads = optimizer.append_regularization_ops(params_grads)
             self.assertEqual(len(params_grads), 1)
             self.assertEqual(len(block.ops), count_ops + 3)

python/paddle/optimizer/momentum.py

Lines changed: 31 additions & 4 deletions
@@ -195,20 +195,47 @@ def _create_accumulators(self, block, parameters):
                 )
             self._add_accumulator(self._velocity_acc_str, p)

+    def _create_regularization_of_grad(self, param, grad, regularization=None):
+        """ Create and add backward regularization Operators
+
+        Function helper of append_regularization_ops.
+        """
+        # If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused
+        # L2Decay with momentum which can refer to _append_optimize_op below.
+        if hasattr(param, 'regularizer') and isinstance(param.regularizer,
+                                                        L2DecayRegularizer):
+            return grad
+        return super(Momentum, self)._create_regularization_of_grad(
+            param, grad, regularization)
+
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)

         velocity_acc = self._get_accumulator(self._velocity_acc_str,
                                              param_and_grad[0])
         lr = self._create_param_lr(param_and_grad)

+        # For fusion of momentum and l2decay
+        param = param_and_grad[0]
+        regularization_method = self._regularization_method
+        regularization_coeff = self._regularization_coeff
+        if hasattr(param, 'regularizer'):
+            # we skip param's l2decay before, so fuse it with momentum here.
+            if isinstance(param.regularizer, L2DecayRegularizer):
+                regularization_method = "l2_decay"
+                regularization_coeff = param.regularizer._regularization_coeff
+            # the param's regularization has been done before, we avoid do l2decay in momentum.
+            elif param.regularizer is not None:
+                regularization_method = ""
+                regularization_coeff = 0
+
         if framework.in_dygraph_mode():
             _, _ = core.ops.momentum(
                 param_and_grad[0], param_and_grad[1], velocity_acc, lr,
                 param_and_grad[0], velocity_acc, 'mu', self._momentum,
                 'use_nesterov', self._use_nesterov, 'regularization_method',
-                self._regularization_method, 'regularization_coeff',
-                self._regularization_coeff)
+                regularization_method, 'regularization_coeff',
+                regularization_coeff)
             return None

         find_master = self._multi_precision and param_and_grad[
@@ -219,8 +246,8 @@ def _append_optimize_op(self, block, param_and_grad):
         attrs = {
             "mu": self._momentum,
             "use_nesterov": self._use_nesterov,
-            "regularization_method": self._regularization_method,
-            "regularization_coeff": self._regularization_coeff,
+            "regularization_method": regularization_method,
+            "regularization_coeff": regularization_coeff,
             "multi_precision": find_master,
             "rescale_grad": self._rescale_grad
         }
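
For readability, the per-parameter decision made in `_append_optimize_op` above can be restated as a small standalone helper. This is an illustrative rewrite, not code from the commit; `L2DecayRegularizer` is the class exported by `paddle.fluid.regularizer`, and the function name is hypothetical:

```python
from paddle.fluid.regularizer import L2DecayRegularizer


def fused_momentum_regularization(param_regularizer, opt_method, opt_coeff):
    """Return the (regularization_method, regularization_coeff) attributes the
    momentum op ends up with for one parameter, mirroring the logic above."""
    if isinstance(param_regularizer, L2DecayRegularizer):
        # The parameter's own L2Decay was skipped in _create_regularization_of_grad,
        # so it is fused into the momentum op here.
        return "l2_decay", param_regularizer._regularization_coeff
    if param_regularizer is not None:
        # Any other per-parameter regularizer (e.g. L1Decay) has already been
        # applied to the gradient, so momentum must not decay the weight again.
        return "", 0
    # No per-parameter regularizer: fall back to the optimizer-level setting,
    # e.g. weight_decay=L2Decay(...) passed to paddle.optimizer.Momentum.
    return opt_method, opt_coeff


# Example matching the new unit tests: no per-parameter regularizer,
# so the optimizer-level L2Decay(0.5) ends up on the momentum op.
print(fused_momentum_regularization(None, "l2_decay", 0.5))  # ('l2_decay', 0.5)
```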
