
Commit 1ebd743

Add linear learning rate warmup method in learning rate scheduler. (#16563)
* Add linear learning rate warmup method

This warmup lr can be combined with other learning rate strategies. For example:

    decayed_lr = fluid.layers.linear_lr_warmup(
        fluid.layers.piecewise_decay(boundaries, lr_steps),
        warmup_steps, start_lr, end_lr)
1 parent a61ed97 commit 1ebd743
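The warmup rule in the commit message is simple enough to check by hand. Below is a minimal pure-Python sketch of the schedule this commit implements, independent of Fluid; the function name and the sample constants are illustrative (the constants are borrowed from the docstring example further down):

    def warmup_then_decay(step, warmup_steps, start_lr, end_lr, decayed_lr):
        # Linear warmup from start_lr toward end_lr, then hand off to the
        # wrapped schedule (represented here by a precomputed decayed_lr).
        if step < warmup_steps:
            return start_lr + (end_lr - start_lr) * (float(step) / warmup_steps)
        return decayed_lr

    # Warm up from 1/3 to 0.1 over 50 steps, then follow a decayed rate of 0.1.
    for step in (0, 25, 49, 50):
        print(step, warmup_then_decay(step, 50, 1. / 3., 0.1, 0.1))
    # -> 0 0.3333..., 25 0.2166..., 49 0.1046..., 50 0.1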

File tree

3 files changed: +102, -4 lines


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -359,6 +359,7 @@ paddle.fluid.layers.piecewise_decay (ArgSpec(args=['boundaries', 'values'], vara
 paddle.fluid.layers.noam_decay (ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None), ('document', 'd9a95746353fd574be36dc28d8726c28'))
 paddle.fluid.layers.append_LARS (ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None), ('document', 'd24fa1e7d62ac8a534fc6a86002f84f8'))
 paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', '9588c64c26ffaef3c466e404a6af9d9b'))
+paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', '2ef3f5ca5cd71ea4217c418e5a7a0565'))
 paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.contrib.StateCell.__init__ (ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.contrib.StateCell.compute_state (ArgSpec(args=['self', 'inputs'], varargs=None, keywords=None, defaults=None), ('document', '92973b3f222081a1d17069c683cf4a99'))

python/paddle/fluid/layers/learning_rate_scheduler.py

Lines changed: 57 additions & 1 deletion
@@ -33,7 +33,7 @@
 __all__ = [
     'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
     'polynomial_decay', 'piecewise_decay', 'noam_decay', 'append_LARS',
-    'cosine_decay'
+    'cosine_decay', 'linear_lr_warmup'
 ]


@@ -383,3 +383,59 @@ def _balanced_weight(param_norm, grad_norm):
             / _balanced_weight(param_norm, grad_norm)
         # set back param local learning rate
         param.optimize_attr['learning_rate'] = decayed_lr
+
+
+def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
+    """
+    Applies linear learning rate warmup before the normal learning rate
+    scheduling.
+
+    .. code-block:: python
+
+        if global_step < warmup_steps:
+            linear_step = end_lr - start_lr
+            lr = start_lr + linear_step * (global_step / warmup_steps)
+
+    Args:
+        learning_rate (float | Variable): A float value or Variable.
+        warmup_steps (int): The warmup steps.
+        start_lr (float): The start learning rate of warmup.
+        end_lr (float): The end learning rate of warmup.
+
+    Returns:
+        The decayed learning rate in the warmup period.
+
+    Examples:
+        .. code-block:: python
+
+            boundaries = [100, 200]
+            lr_steps = [0.1, 0.01, 0.001]
+            warmup_steps = 50
+            start_lr = 1. / 3.
+            end_lr = 0.1
+            decayed_lr = fluid.layers.linear_lr_warmup(
+                fluid.layers.piecewise_decay(boundaries, lr_steps),
+                warmup_steps, start_lr, end_lr)
+
+    """
+    assert isinstance(end_lr, float)
+    assert isinstance(start_lr, float)
+    linear_step = end_lr - start_lr
+    with default_main_program()._lr_schedule_guard():
+        lr = tensor.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=True,
+            name="learning_rate_warmup")
+
+        global_step = _decay_step_counter()
+
+        with control_flow.Switch() as switch:
+            with switch.case(global_step < warmup_steps):
+                decayed_lr = start_lr + linear_step * (global_step /
+                                                       float(warmup_steps))
+                tensor.assign(decayed_lr, lr)
+            with switch.default():
+                tensor.assign(learning_rate, lr)
+        return lr
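In training code, the variable returned by linear_lr_warmup is consumed like any other learning-rate variable. A hedged usage sketch follows; the boundaries, rates, warmup constants, and the choice of SGD optimizer are illustrative assumptions, not part of this commit:

    import paddle.fluid as fluid

    boundaries = [100, 200]
    lr_steps = [0.1, 0.01, 0.001]

    # Warm up from 1/3 to 0.1 over the first 50 steps, then follow the
    # piecewise decay defined above.
    decayed_lr = fluid.layers.linear_lr_warmup(
        fluid.layers.piecewise_decay(boundaries, lr_steps),
        50, 1. / 3., 0.1)

    # The warmed-up rate feeds the optimizer as a Variable.
    optimizer = fluid.optimizer.SGD(learning_rate=decayed_lr)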

python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py

Lines changed: 44 additions & 3 deletions
@@ -120,9 +120,9 @@ def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
         self.assertAlmostEqual(
             python_decayed_lr,
             lr_val[0],
-            msg='Failed fn is {0}, Python result is {1}, Fluid result is {2}'.
+            msg='Failed lr scheduler is {0}, step {1}, Python result is {2}, Fluid result is {3}'.
             format(python_decay_fn.__name__,
-                   str(python_decayed_lr), str(lr_val[0])))
+                   str(step), str(python_decayed_lr), str(lr_val[0])))

     def test_decay(self):
         common_kwargs_true = {

@@ -164,12 +164,53 @@ def test_decay(self):
         ]

         for py_decay_fn, fluid_decay_fn, kwargs in decay_fns:
-            print("decay_fn=" + py_decay_fn.__name__ + " kwargs=" + str(kwargs))
+            print("class=" + self.__class__.__name__ + " decay_fn=" +
+                  py_decay_fn.__name__ + " kwargs=" + str(kwargs))
             main_program = framework.Program()
             startup_program = framework.Program()
             with framework.program_guard(main_program, startup_program):
                 self.check_decay(py_decay_fn, fluid_decay_fn, kwargs)


+def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr):
+    linear_step = end_lr - start_lr
+    decayed_lr = start_lr + linear_step * (global_step / warmup_steps)
+    return decayed_lr
+
+
+class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
+    def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
+                               kwargs):
+        main_prog = fluid.Program()
+        startup_prog = fluid.Program()
+
+        warmup_steps = 10
+        start_lr = 1. / 3.
+        end_lr = 0.1
+
+        with fluid.program_guard(main_prog, startup_prog):
+            decayed_lr = layers.linear_lr_warmup(
+                fluid_decay_fn(**kwargs), warmup_steps, start_lr, end_lr)
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(startup_prog)
+
+        for step in range(20):
+            lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
+            if step < warmup_steps:
+                python_decayed_lr = linear_lr_warmup(
+                    float(step), warmup_steps, start_lr, end_lr)
+            else:
+                python_decayed_lr = python_decay_fn(
+                    global_step=float(step), **kwargs)
+            self.assertAlmostEqual(
+                python_decayed_lr,
+                lr_val[0],
+                msg='Test {0} Failed, step {1}, Python result is {2}, Fluid result is {3}'.
+                format(python_decay_fn.__name__,
+                       str(step), str(python_decayed_lr), str(lr_val[0])))
+
+
 if __name__ == '__main__':
     unittest.main()
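As a hand check of the constants this test uses (warmup_steps = 10, start_lr = 1/3, end_lr = 0.1): the pure-Python helper gives lr(0) = 1/3 ≈ 0.3333 and lr(5) = 1/3 + (0.1 - 1/3) * 5/10 = 13/60 ≈ 0.2167, and from step 10 onward the assertion switches to the wrapped Python decay function. Illustrative spot checks against the linear_lr_warmup helper defined in the test above:

    # Hand-computed warmup values; tolerances are illustrative.
    assert abs(linear_lr_warmup(0.0, 10, 1. / 3., 0.1) - 1. / 3.) < 1e-12
    assert abs(linear_lr_warmup(5.0, 10, 1. / 3., 0.1) - 13. / 60.) < 1e-12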
