
Commit d171f37

[CHERR-PICK1.8]add base class LearningRateEpochDecay, and MultiStepDecay, StepDecay (#25277)
* CHERR-PICK1.8,add base class of LearningRateEpochDecay, and API: MultiStepDecay, and API: StepDecay,test=release/1.8
* fix unittest to add coverage,test=develop
1 parent 43facfd commit d171f37
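For orientation, below is a minimal dygraph usage sketch of the two schedulers added by this commit. It mirrors the docstring examples introduced in the diff and assumes the paddle.fluid dygraph API of release/1.8; the toy network and data are illustrative only.

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        linear = fluid.dygraph.Linear(10, 10)
        data = fluid.dygraph.to_variable(
            np.random.uniform(-1, 1, [10, 10]).astype("float32"))

        # StepDecay: scale the learning rate by decay_rate every step_size epochs.
        # MultiStepDecay is constructed analogously, e.g.
        # fluid.dygraph.MultiStepDecay(0.5, milestones=[3, 5]).
        scheduler = fluid.dygraph.StepDecay(0.5, step_size=3, decay_rate=0.1)
        adam = fluid.optimizer.Adam(
            learning_rate=scheduler, parameter_list=linear.parameters())

        for epoch in range(9):
            for _ in range(5):
                loss = fluid.layers.reduce_mean(linear(data))
                adam.minimize(loss)
            scheduler.epoch()  # advance the schedule once per epoch
            print(epoch, adam.current_step_lr())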

3 files changed: +356 −69 lines changed


python/paddle/fluid/dygraph/learning_rate_scheduler.py

Lines changed: 233 additions & 3 deletions
@@ -23,7 +23,7 @@
 __all__ = [
     'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
     'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup',
-    'ReduceLROnPlateau'
+    'ReduceLROnPlateau', 'StepDecay', 'MultiStepDecay'
 ]


@@ -72,6 +72,8 @@ def step(self):

 class PiecewiseDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Piecewise decay scheduler.

     The algorithm can be described as the code below.
@@ -131,6 +133,8 @@ def step(self):

 class NaturalExpDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Applies natural exponential decay to the initial learning rate.

     The algorithm can be described as following.
@@ -183,7 +187,6 @@ class NaturalExpDecay(LearningRateDecay):
                     staircase=True),
                 parameter_list=emb.parameters())

-
     """

     def __init__(self,
@@ -213,6 +216,8 @@ def step(self):

 class ExponentialDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Applies exponential decay to the learning rate.

     The algorithm can be described as following.
@@ -293,6 +298,8 @@ def step(self):

 class InverseTimeDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Applies inverse time decay to the initial learning rate.

     The algorithm can be described as following.
@@ -369,6 +376,8 @@ def step(self):

 class PolynomialDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Applies polynomial decay to the initial learning rate.

     The algorithm can be described as following.
@@ -461,6 +470,8 @@ def step(self):

 class CosineDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Applies cosine decay to the learning rate.

     The algorithm can be described as following.
@@ -517,6 +528,8 @@ def step(self):

 class NoamDecay(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Applies Noam decay to the initial learning rate.

     The algorithm can be described as following.
@@ -582,6 +595,8 @@ def step(self):

 class LinearLrWarmup(LearningRateDecay):
     """
+    :api_attr: imperative
+
     This operator use the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
     For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

@@ -670,6 +685,8 @@ def step(self):

 class ReduceLROnPlateau(LearningRateDecay):
     """
+    :api_attr: imperative
+
     Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate
     by 2 to 10 times once model performance has no longer improvement.

@@ -774,7 +791,6 @@ def __init__(self,
             raise ValueError('threshold mode ' + threshold_mode +
                              ' is unknown!')
         self.threshold_mode = threshold_mode
-
         check_type(learning_rate, 'learning_rate', (float, int, Variable),
                    'ReduceLROnPlateau')
         if isinstance(learning_rate, (float, int)):
@@ -856,3 +872,217 @@ def _is_better(self, current, best):

         else:
             return current > best + self.threshold
+
+
+class _LearningRateEpochDecay(LearningRateDecay):
+    """
+    :api_attr: imperative
+
+    Base class of learning rate decay, which is updated each epoch.
+
+    Defines the common interface of a _LearningRateEpochDecay. Users should not
+    use this class directly, but should use one of its implementations and
+    invoke the method `epoch()` once each epoch.
+    """
+
+    def __init__(self, learning_rate, dtype=None):
+        if not isinstance(learning_rate, (float, int)):
+            raise TypeError(
+                "The type of 'learning_rate' must be 'float, int', but received %s."
+                % type(learning_rate))
+        if learning_rate >= 1.0:
+            raise ValueError(
+                "The initial learning rate must be less than 1.0, but received %f."
+                % learning_rate)
+
+        self.base_lr = float(learning_rate)
+
+        self.epoch_num = -1
+        if dtype is None:
+            self.dtype = "float32"
+        self.learning_rate = self.create_lr_var(self.base_lr)
+
+        self.epoch()
+
+    def __call__(self):
+        """
+        Return the last computed learning rate of the current epoch.
+        """
+        return self.learning_rate
+
+    def epoch(self, epoch=None):
+        """
+        Compute the learning rate and update it when invoked.
+        """
+        if epoch is None:
+            self.epoch_num += 1
+        else:
+            self.epoch_num = epoch
+
+        self.learning_rate = self.get_lr()
+        if isinstance(self.learning_rate, float):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
+
+    def get_lr(self):
+        raise NotImplementedError
+
+
+class StepDecay(_LearningRateEpochDecay):
+    """
+    :api_attr: imperative
+
+    Decays the learning rate of ``optimizer`` by ``decay_rate`` every ``step_size`` epochs.
+
+    The algorithm can be described as the code below.
+
+    .. code-block:: text
+
+        learning_rate = 0.5
+        step_size = 30
+        decay_rate = 0.1
+
+        learning_rate = 0.5    if epoch < 30
+        learning_rate = 0.05   if 30 <= epoch < 60
+        learning_rate = 0.005  if 60 <= epoch < 90
+        ...
+
+    Parameters:
+        learning_rate (float|int): The initial learning rate. It can be set to a python float or int number.
+        step_size (int): Period of learning rate decay.
+        decay_rate (float, optional): The ratio by which the learning rate will be reduced:
+            ``new_lr = origin_lr * decay_rate``. It should be less than 1.0. Default: 0.1.
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import numpy as np
+            with fluid.dygraph.guard():
+                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
+                linear = fluid.dygraph.Linear(10, 10)
+                input = fluid.dygraph.to_variable(x)
+                scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
+                adam = fluid.optimizer.Adam(learning_rate=scheduler, parameter_list=linear.parameters())
+
+                for epoch in range(9):
+                    for batch_id in range(5):
+                        out = linear(input)
+                        loss = fluid.layers.reduce_mean(out)
+                        adam.minimize(loss)
+                    scheduler.epoch()
+
+                    print("epoch:{}, current lr is {}".format(epoch, adam.current_step_lr()))
+                    # epoch:0, current lr is 0.5
+                    # epoch:1, current lr is 0.5
+                    # epoch:2, current lr is 0.5
+                    # epoch:3, current lr is 0.05
+                    # epoch:4, current lr is 0.05
+                    # epoch:5, current lr is 0.05
+                    # epoch:6, current lr is 0.005
+                    # epoch:7, current lr is 0.005
+                    # epoch:8, current lr is 0.005
+
+    """
+
+    def __init__(self, learning_rate, step_size, decay_rate=0.1):
+        if not isinstance(step_size, int):
+            raise TypeError(
+                "The type of 'step_size' must be 'int', but received %s." %
+                type(step_size))
+        if decay_rate >= 1.0:
+            raise ValueError('decay_rate should be < 1.0.')
+
+        self.step_size = step_size
+        self.decay_rate = decay_rate
+        super(StepDecay, self).__init__(learning_rate)
+
+    def get_lr(self):
+        decay_rate = self.create_lr_var(self.decay_rate)
+        i = self.epoch_num // self.step_size
+        return self.base_lr * (decay_rate**i)
+
+
+class MultiStepDecay(_LearningRateEpochDecay):
+    """
+    :api_attr: imperative
+
+    Decays the learning rate of ``optimizer`` by ``decay_rate`` once ``epoch`` reaches one of the milestones.
+
+    The algorithm can be described as the code below.
+
+    .. code-block:: text
+
+        learning_rate = 0.5
+        milestones = [30, 50]
+        decay_rate = 0.1
+        if epoch < 30:
+            learning_rate = 0.5
+        elif epoch < 50:
+            learning_rate = 0.05
+        else:
+            learning_rate = 0.005
+
+    Parameters:
+        learning_rate (float|int): The initial learning rate. It can be set to a python float or int number.
+        milestones (tuple|list): List or tuple of epoch boundaries. Must be increasing.
+        decay_rate (float, optional): The ratio by which the learning rate will be reduced:
+            ``new_lr = origin_lr * decay_rate``. It should be less than 1.0. Default: 0.1.
+
+    Returns:
+        None.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import numpy as np
+            with fluid.dygraph.guard():
+                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
+                linear = fluid.dygraph.Linear(10, 10)
+                input = fluid.dygraph.to_variable(x)
+                scheduler = fluid.dygraph.MultiStepDecay(0.5, milestones=[3, 5])
+                adam = fluid.optimizer.Adam(learning_rate=scheduler, parameter_list=linear.parameters())
+
+                for epoch in range(6):
+                    for batch_id in range(5):
+                        out = linear(input)
+                        loss = fluid.layers.reduce_mean(out)
+                        adam.minimize(loss)
+                    scheduler.epoch()
+
+                    print("epoch:{}, current lr is {}".format(epoch, adam.current_step_lr()))
+                    # epoch:0, current lr is 0.5
+                    # epoch:1, current lr is 0.5
+                    # epoch:2, current lr is 0.5
+                    # epoch:3, current lr is 0.05
+                    # epoch:4, current lr is 0.05
+                    # epoch:5, current lr is 0.005
+
+    """
+
+    def __init__(self, learning_rate, milestones, decay_rate=0.1):
+        if not isinstance(milestones, (tuple, list)):
+            raise TypeError(
+                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
+                % type(milestones))
+
+        if not all([
+                milestones[i] < milestones[i + 1]
+                for i in range(len(milestones) - 1)
+        ]):
+            raise ValueError('The elements of milestones must be increasing.')
+        if decay_rate >= 1.0:
+            raise ValueError('decay_rate should be < 1.0.')
+
+        self.milestones = milestones
+        self.decay_rate = decay_rate
+        super(MultiStepDecay, self).__init__(learning_rate)
+
+    def get_lr(self):
+        decay_rate = self.create_lr_var(self.decay_rate)
+        for i in range(len(self.milestones)):
+            if self.epoch_num < self.milestones[i]:
+                return self.base_lr * (decay_rate**i)
+
+        return self.base_lr * (decay_rate**len(self.milestones))
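The two `get_lr()` implementations above reduce to simple per-epoch arithmetic. Below is a standalone sketch of that arithmetic in plain Python (no Paddle dependency; the function names `step_decay_lr` and `multi_step_decay_lr` are illustrative, not part of the Paddle API), which reproduces the schedules printed in the docstring examples:

    def step_decay_lr(base_lr, epoch, step_size, decay_rate=0.1):
        # StepDecay: multiply by decay_rate once every `step_size` epochs.
        return base_lr * decay_rate ** (epoch // step_size)

    def multi_step_decay_lr(base_lr, epoch, milestones, decay_rate=0.1):
        # MultiStepDecay: multiply by decay_rate each time a milestone is passed.
        for i, m in enumerate(milestones):
            if epoch < m:
                return base_lr * decay_rate ** i
        return base_lr * decay_rate ** len(milestones)

    print([step_decay_lr(0.5, e, step_size=3) for e in range(9)])
    # 0.5 for epochs 0-2, 0.05 for epochs 3-5, 0.005 for epochs 6-8 (up to float rounding)
    print([multi_step_decay_lr(0.5, e, milestones=[3, 5]) for e in range(6)])
    # 0.5 for epochs 0-2, 0.05 for epochs 3-4, 0.005 for epoch 5 (up to float rounding)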

python/paddle/fluid/layers/learning_rate_scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -498,7 +498,6 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
     Returns:
         Variable: Warm-up learning rate with the same data type as learning_rate.

-
     Examples:

         .. code-block:: python
