
Commit 880fb83

Author: Yibing Liu

[cherry-pick] Update lamb optimizer (#18333) (#18380)

* Update lamb optimizer (#18333)
* Update lamb optimizer
* Regenerate api spec test=release/1.5
* Give an experimental warning test=release/1.5
1 parent 5b103c2 commit 880fb83

File tree

5 files changed (+40 lines, -58 lines)


paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
@@ -861,7 +861,7 @@ paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'los
 paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.DGCMomentumOptimizer.load (ArgSpec(args=['self', 'stat_dict'], varargs=None, keywords=None, defaults=None), ('document', '649a92cf7f1ea28666fd00c4ea01acde'))
 paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b'))
-paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'exclude_from_weight_decay_fn', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.LambOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
 paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))
 paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))

paddle/fluid/operators/optimizers/lamb_op.cc

Lines changed: 4 additions & 14 deletions
@@ -60,23 +60,13 @@ correction. For more information, please refer to https://arxiv.org/abs/1904.009
 The updating of parameters follows:
 
 $$
-m_t^l &= \beta_1 m_{t - 1}^l + (1 - \beta_1)g_t^l \\
+m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t \\
 
-v_t^l &= \beta_2 v_{t - 1}^l + (1 - \beta_2)g_t^l \odot g_t^l \\
+v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2 \\
 
-\widehat{m}_t^l &= m_t^l/(1 - \beta_1^t) \\
+r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon} \\
 
-\widehat{v}_t^l &= v_t^l/(1 - \beta_2^t) \\
-
-r_1 &= \left \| w_{t-1}^l \right \|_2 \\
-
-r_2 &= \left \| \frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l \right \|_2 \\
-
-r &= r_1 / r_2 \\
-
-\eta^l &= r \times \eta \\
-
-w_t^l &= w_{t-1}^l -\eta ^l \times (\frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l)
+w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})
 $$
 
 where $m$ is the 1st moment, and $v$ the 2nd moment, $\eta$ the
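
For reference, here is a minimal NumPy sketch of the simplified rule the new doc block describes: raw (non-bias-corrected) moments, an element-wise Adam-style ratio r_t, and a per-parameter trust ratio ||w_{t-1}|| / ||r_t + lambda * w_{t-1}||. The function and argument names are illustrative only, not Paddle API; the real computation is done by the functors in lamb_op.h below.

```python
import numpy as np

def lamb_step_sketch(w, g, m, v, lr, beta1=0.9, beta2=0.999,
                     epsilon=1e-6, weight_decay=0.01):
    """One LAMB update following the revised formula (no bias correction)."""
    m = beta1 * m + (1 - beta1) * g        # m_t
    v = beta2 * v + (1 - beta2) * g * g    # v_t
    r = m / (np.sqrt(v) + epsilon)         # r_t, epsilon outside the sqrt
    update = r + weight_decay * w          # r_t + lambda * w_{t-1}
    trust_ratio = np.linalg.norm(w) / np.linalg.norm(update)
    w = w - lr * trust_ratio * update      # w_t
    return w, m, v
```

Compared with the previous docstring, the bias-corrected moments and the separate r_1 / r_2 / eta^l steps are folded into this single trust-ratio form, and epsilon moves outside the square root, matching the kernel and the reference test below.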

paddle/fluid/operators/optimizers/lamb_op.h

Lines changed: 2 additions & 12 deletions
@@ -66,19 +66,14 @@ struct LambMomentUpdateFunctor {
     T g = grad_[i];
     T mom1 = moment1_[i];
     T mom2 = moment2_[i];
-    T beta1_pow = *beta1_pow_;
-    T beta2_pow = *beta2_pow_;
     T p = param_[i];
 
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
 
-    T mom1_h = mom1 / (1 - beta1_pow);
-    T mom2_h = mom2 / (1 - beta2_pow);
-
     moment1_out_[i] = mom1;
     moment2_out_[i] = mom2;
-    trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p;
+    trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
   }
 };
 
@@ -130,19 +125,14 @@ struct SparseLambMomentUpdateFunctor {
     // The following code is same as dense
     T mom1 = moment1_[i];
     T mom2 = moment2_[i];
-    T beta1_pow = *beta1_pow_;
-    T beta2_pow = *beta2_pow_;
     T p = param_[i];
 
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
 
-    T mom1_h = mom1 / (1 - beta1_pow);
-    T mom2_h = mom2 / (1 - beta2_pow);
-
     moment1_out_[i] = mom1;
     moment2_out_[i] = mom2;
-    trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p;
+    trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
   }
 
   inline HOSTDEVICE void operator()(size_t i) const {
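
To make the kernel change concrete, here is the per-element trust_ratio_div term before and after this commit, transcribed into NumPy (the helper names are made up for illustration; they are not part of the Paddle codebase):

```python
import numpy as np

def trust_ratio_div_old(mom1, mom2, p, beta1_pow, beta2_pow,
                        epsilon=1e-6, weight_decay=0.01):
    # Before: bias-corrected moments, epsilon added inside the sqrt.
    mom1_h = mom1 / (1 - beta1_pow)
    mom2_h = mom2 / (1 - beta2_pow)
    return mom1_h / np.sqrt(mom2_h + epsilon) + weight_decay * p

def trust_ratio_div_new(mom1, mom2, p, epsilon=1e-6, weight_decay=0.01):
    # After: raw moments, epsilon added outside the sqrt.
    return mom1 / (np.sqrt(mom2) + epsilon) + weight_decay * p
```

Since the functors no longer read beta1_pow / beta2_pow, those accumulators are kept only on the Python side, which the commit marks as "not used in op temporarily".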

python/paddle/fluid/optimizer.py

Lines changed: 27 additions & 19 deletions
@@ -2077,30 +2077,20 @@ class LambOptimizer(AdamOptimizer):
 
     LAMB Optimizer is designed to scale up the batch size of training without losing
     accuracy, which supports adaptive element-wise updating and accurate layer-wise
-    correction. For more information, please refer to `Reducing BERT Pre-Training
-    Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_ .
+    correction. For more information, please refer to `Large Batch Optimization for
+    Deep Learning: Training BERT in 76 minutes <https://arxiv.org/abs/1904.00962>`_ .
 
     The updating of parameters follows:
 
     .. math::
 
-        m_t^l & = \\beta_1 m_{t - 1}^l + (1 - \\beta_1)g_t^l
+        m_t &= \\beta_1 m_{t - 1}+ (1 - \\beta_1)g_t \\
 
-        v_t^l & = \\beta_2 v_{t - 1}^l + (1 - \\beta_2)g_t^l \odot g_t^l
+        v_t &= \\beta_2 v_{t - 1} + (1 - \\beta_2)g_t^2 \\
 
-        \\widehat{m}_t^l & = m_t^l/(1 - \\beta_1^t)
+        r_t &= \\frac{m_t}{\\sqrt{v_t}+\\epsilon} \\
 
-        \\widehat{v}_t^l & = v_t^l/(1 - \\beta_2^t)
-
-        r_1 & = \\left \| w_{t-1}^l \\right \|_2
-
-        r_2 & = \\left \| \\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l \\right \|_2
-
-        r & = r_1 / r_2
-
-        \\eta^l & = r \\times \\eta
-
-        w_t^l & = w_{t-1}^l -\\eta ^l \\times (\\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l)
+        w_t &= w_{t-1} -\\eta_t \\frac{\\left \| w_{t-1}\\right \|}{\\left \| r_t + \\lambda w_{t-1}\\right \|} (r_t + \\lambda w_{t-1})
 
 
     where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\\eta` the
@@ -2114,8 +2104,10 @@ class LambOptimizer(AdamOptimizer):
         beta1 (float): The exponential decay rate for the 1st moment estimates.
         beta2 (float): The exponential decay rate for the 2nd moment estimates.
         epsilon (float): A small float value for numerical stability.
-        regularization: A Regularizer, such as
+        regularization (Regularizer): A Regularizer, such as
             fluid.regularizer.L1DecayRegularizer.
+        exclude_from_weight_decay_fn (function): Exclude a parameter from weight
+            decay when **exclude_from_weight_decay_fn(parameter)** returns true.
         name (str|None): An optional name prefix.
 
     Examples:
@@ -2127,11 +2119,16 @@ class LambOptimizer(AdamOptimizer):
             hidden = fluid.layers.fc(input=data, size=10)
             cost = fluid.layers.mean(hidden)
 
-            optimizer = fluid.optimizer.Lamb(learning_rate=0.002)
+            def exclude_fn(param):
+                return param.name.endswith('.b_0')
+
+            optimizer = fluid.optimizer.Lamb(learning_rate=0.002,
+                                             exclude_from_weight_decay_fn=exclude_fn)
             optimizer.minimize(cost)
     """
     _moment1_acc_str = "moment1"
     _moment2_acc_str = "moment2"
+    # these two not used in op temporarily
     _beta1_pow_acc_str = "beta1_pow_acc"
     _beta2_pow_acc_str = "beta2_pow_acc"
 
@@ -2142,6 +2139,7 @@ def __init__(self,
                  beta2=0.999,
                  epsilon=1e-6,
                  regularization=None,
+                 exclude_from_weight_decay_fn=None,
                  name=None):
         assert learning_rate is not None
         assert lamb_weight_decay is not None
@@ -2157,6 +2155,10 @@ def __init__(self,
             name=name)
         self.type = "lamb"
         self._weight_decay = lamb_weight_decay
+        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
+        print(
+            "WARNING: The LAMB optimizer doesn't have official implementation "
+            "yet and is still in experimental.")
 
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
@@ -2170,6 +2172,12 @@ def _append_optimize_op(self, block, param_and_grad):
         beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
                                               param_and_grad[0])
 
+        if self._exclude_from_weight_decay_fn is not None \
+                and self._exclude_from_weight_decay_fn(param_and_grad[0]):
+            weight_decay = 0.0
+        else:
+            weight_decay = self._weight_decay
+
         # create the lamb optimize op
         lamb_op = block.append_op(
             type=self.type,
@@ -2191,7 +2199,7 @@ def _append_optimize_op(self, block, param_and_grad):
                 "beta1": self._beta1,
                 "beta2": self._beta2,
                 "epsilon": self._epsilon,
-                "weight_decay": self._weight_decay
+                "weight_decay": weight_decay
             },
             stop_gradient=True)
 
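
The new exclude_from_weight_decay_fn hook is consulted once per parameter when the lamb op is appended. In plain Python (a stand-alone sketch with made-up names, no Paddle dependency), the selection behaves like this:

```python
class FakeParam:
    """Stand-in for a framework parameter; only the name matters here."""

    def __init__(self, name):
        self.name = name

def pick_weight_decay(param, base_decay, exclude_fn=None):
    # Mirrors the branch added to _append_optimize_op: excluded parameters
    # get weight_decay = 0.0 in the lamb op's attributes.
    if exclude_fn is not None and exclude_fn(param):
        return 0.0
    return base_decay

def exclude_fn(param):
    return param.name.endswith('.b_0')   # e.g. skip decay for bias parameters

print(pick_weight_decay(FakeParam('fc_0.w_0'), 0.01, exclude_fn))  # 0.01
print(pick_weight_decay(FakeParam('fc_0.b_0'), 0.01, exclude_fn))  # 0.0
```

The exclusion only changes the weight_decay attribute passed to the op; the rest of the update (moments, trust ratio) is unchanged for excluded parameters.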

python/paddle/fluid/tests/unittests/test_lamb_op.py

Lines changed: 6 additions & 12 deletions
@@ -140,15 +140,12 @@ def lamb_step(inputs, attributes):
     moment1_out = beta1 * moment1 + (1 - beta1) * grad
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
 
-    mom1_tmp = moment1_out / (1 - beta1_pow)
-    mom2_tmp = moment2_out / (1 - beta2_pow)
-
     r_1 = np.linalg.norm(param)
-    r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) + weight_decay *
-                         param)
+    r_2 = np.linalg.norm(moment1_out / (np.sqrt(moment2_out) + epsilon) +
+                         weight_decay * param)
     lr_t = lr * r_1 / r_2
 
-    param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
+    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon) +
                                 weight_decay * param)
     return param_out, moment1_out, moment2_out
 
@@ -190,16 +187,13 @@ def update_mom(row_id, update_value):
                                1 - beta2) * np.square(update_value)
 
     def update_param():
-        mom1_tmp = moment1_out / (1 - beta1_pow)
-        mom2_tmp = moment2_out / (1 - beta2_pow)
-
         r_1 = np.linalg.norm(param)
-        r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
+        r_2 = np.linalg.norm(moment1_out / (np.sqrt(moment2_out) + epsilon) +
                              weight_decay * param)
         lr_t = lr * r_1 / r_2
 
-        param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) +
-                                    weight_decay * param)
+        param_out = param - lr_t * (moment1_out / (
+            np.sqrt(moment2_out) + epsilon) + weight_decay * param)
 
     for row_id in range(param_out.shape[0]):
         update_value = np.zeros(np_grad[0].shape).astype("float32")
