
Commit d8d73ff

Merge pull request #15584 from velconia/imperative_lr_scheduler
Support imperative learning rate scheduler
2 parents 1ebd743 + 64b0929 commit d8d73ff
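
The schedulers added by this PR are meant to be driven step by step from an ordinary Python loop in imperative (dygraph) mode, rather than being compiled into a static program. A minimal usage sketch, assuming the fluid.dygraph.guard() context manager and leaving out any optimizer wiring (neither is shown in the excerpts below):

import paddle.fluid as fluid
from paddle.fluid.dygraph import PiecewiseDecay

# Illustrative sketch only; guard() and the training loop are assumptions,
# only the scheduler classes come from this commit.
with fluid.dygraph.guard():
    # lr is 1.0 for steps [0, 100), 0.5 for [100, 200), then 0.1 afterwards.
    scheduler = PiecewiseDecay(boundaries=[100, 200], values=[1.0, 0.5, 0.1], begin=0)

    for _ in range(300):
        lr = scheduler()  # returns the learning-rate Variable for the current
                          # step and advances step_num by step_size
        # ... forward pass, backward pass and parameter update would go here ...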

File tree

7 files changed: +773, -230 lines changed


python/paddle/fluid/dygraph/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -32,10 +32,14 @@
 from . import checkpoint
 from .checkpoint import *

+from . import learning_rate_scheduler
+from .learning_rate_scheduler import *
+
 __all__ = []
 __all__ += layers.__all__
 __all__ += base.__all__
 __all__ += nn.__all__
 __all__ += tracer.__all__
 __all__ += profiler.__all__
 __all__ += checkpoint.__all__
+__all__ += learning_rate_scheduler.__all__
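
With the wildcard import and the __all__ update above, the new scheduler classes become importable directly from the paddle.fluid.dygraph namespace:

# All seven names exposed by this commit's __all__ re-export:
from paddle.fluid.dygraph import (NoamDecay, PiecewiseDecay, NaturalExpDecay,
                                  ExponentialDecay, InverseTimeDecay,
                                  PolynomialDecay, CosineDecay)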
python/paddle/fluid/dygraph/learning_rate_scheduler.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import math

from .. import unique_name

__all__ = [
    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay'
]


class LearningRateDecay(object):
    """
    Base class of learning rate decay
    """

    def __init__(self, begin=0, step=1, dtype='float32'):
        self.step_num = begin
        self.step_size = step
        self.dtype = dtype

    def __call__(self):
        lr = self.step()
        if isinstance(lr, float):
            lr = self.create_lr_var(lr)
        self.step_num += self.step_size
        return lr

    def create_lr_var(self, lr):
        from .. import layers
        lr = layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(lr),
            dtype=self.dtype,
            persistable=True)
        return lr

    def step(self):
        raise NotImplementedError()


class PiecewiseDecay(LearningRateDecay):
    def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
        super(PiecewiseDecay, self).__init__(begin, step, dtype)
        self.boundaries = boundaries
        self.values = values

        self.vars = []
        for value in values:
            self.vars.append(self.create_lr_var(value))

    def step(self):
        for i in range(len(self.boundaries)):
            if self.step_num < self.boundaries[i]:
                return self.vars[i]
        return self.vars[len(self.values) - 1]


class NaturalExpDecay(LearningRateDecay):
    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(NaturalExpDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        div_res = self.create_lr_var(self.step_num / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)
        decayed_lr = self.learning_rate * layers.exp(-1 * self.decay_rate *
                                                     div_res)

        return decayed_lr


class ExponentialDecay(LearningRateDecay):
    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(ExponentialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        div_res = self.create_lr_var(self.step_num / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate * (self.decay_rate**div_res)

        return decayed_lr


class InverseTimeDecay(LearningRateDecay):
    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(InverseTimeDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        from .. import layers
        div_res = self.create_lr_var(self.step_num / self.decay_steps)
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)

        return decayed_lr


class PolynomialDecay(LearningRateDecay):
    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_learning_rate=0.0001,
                 power=1.0,
                 cycle=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(PolynomialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        from .. import layers
        tmp_step_num = self.step_num
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            div_res = layers.ceil(
                self.create_lr_var(tmp_step_num / float(self.decay_steps)))

            if tmp_step_num == 0:
                div_res = self.create_lr_var(1.0)
            tmp_decay_steps = self.decay_steps * div_res
        else:
            tmp_step_num = self.create_lr_var(tmp_step_num
                                              if tmp_step_num < self.decay_steps
                                              else self.decay_steps)

        decayed_lr = (self.learning_rate - self.end_learning_rate) * \
            ((1 - tmp_step_num / tmp_decay_steps) ** self.power) + self.end_learning_rate
        return decayed_lr


class CosineDecay(LearningRateDecay):
    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(CosineDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.step_each_epoch = step_each_epoch
        self.epochs = epochs

    def step(self):
        from .. import layers
        cur_epoch = layers.floor(
            self.create_lr_var(self.step_num / self.step_each_epoch))
        decayed_lr = self.learning_rate * 0.5 * (
            layers.cos(cur_epoch * math.pi / self.epochs) + 1)
        return decayed_lr


class NoamDecay(LearningRateDecay):
    def __init__(self, d_model, warmup_steps, begin=1, step=1, dtype='float32'):
        super(NoamDecay, self).__init__(begin, step, dtype)
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self):
        from .. import layers
        a = self.create_lr_var(self.step_num**-0.5)
        b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
        lr_value = (self.d_model**-0.5) * layers.elementwise_min(a, b)
        return lr_value
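
The decay curves themselves can be sanity-checked without the framework. A plain-Python mirror of the NaturalExpDecay and CosineDecay formulas above (illustrative only, not part of the commit):

import math

def natural_exp_decay(lr, step_num, decay_steps, decay_rate, staircase=False):
    # Mirrors NaturalExpDecay.step(): lr * exp(-decay_rate * step_num / decay_steps)
    div_res = step_num / decay_steps
    if staircase:
        div_res = math.floor(div_res)
    return lr * math.exp(-1 * decay_rate * div_res)

def cosine_decay(lr, step_num, step_each_epoch, epochs):
    # Mirrors CosineDecay.step(): lr * 0.5 * (cos(cur_epoch * pi / epochs) + 1)
    cur_epoch = math.floor(step_num / step_each_epoch)
    return lr * 0.5 * (math.cos(cur_epoch * math.pi / epochs) + 1)

print(natural_exp_decay(0.1, step_num=500, decay_steps=100, decay_rate=0.5))  # ~0.0082
print(cosine_decay(0.1, step_num=300, step_each_epoch=100, epochs=10))        # ~0.0794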
