
Commit 7d40232

feature: AdaBound optimizer
1 parent 1e27057 commit 7d40232

File tree

1 file changed: +160 -0 lines

pytorch_optimizer/adabound.py

Lines changed: 160 additions & 0 deletions
@@ -13,6 +13,149 @@
 )
 
 
+class AdaBound(Optimizer):
+    """
+    Reference : https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py
+    Example :
+        from pytorch_optimizer import AdaBound
+        ...
+        model = YourModel()
+        optimizer = AdaBound(model.parameters())
+        ...
+        for input, output in data:
+            optimizer.zero_grad()
+            loss = loss_function(output, model(input))
+            loss.backward()
+            optimizer.step()
+    """
+
+    def __init__(
+        self,
+        params: PARAMS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.9, 0.999),
+        final_lr: float = 0.1,
+        gamma: float = 1e-3,
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        amsbound: bool = False,
+    ):
+        """AdaBound optimizer
+        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        :param lr: float. learning rate
+        :param final_lr: float. final learning rate
+        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
+        :param gamma: float. convergence speed of the bound functions
+        :param eps: float. term added to the denominator to improve numerical stability
+        :param weight_decay: float. weight decay (L2 penalty)
+        :param amsbound: bool. whether to use the AMSBound variant
+        """
+        self.lr = lr
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+
+        defaults: DEFAULT_PARAMETERS = dict(
+            lr=lr,
+            betas=betas,
+            final_lr=final_lr,
+            gamma=gamma,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsbound=amsbound,
+        )
+        super().__init__(params, defaults)
+
+        self.base_lrs = [group['lr'] for group in self.param_groups]
+
+    def check_valid_parameters(self):
+        if 0.0 > self.lr:
+            raise ValueError(f'Invalid learning rate : {self.lr}')
+        if 0.0 > self.eps:
+            raise ValueError(f'Invalid eps : {self.eps}')
+        if 0.0 > self.weight_decay:
+            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
+        if not 0.0 <= self.betas[0] < 1.0:
+            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
+        if not 0.0 <= self.betas[1] < 1.0:
+            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
+
+    def __setstate__(self, state: STATE):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsbound', False)
+
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            loss = closure()
+
+        for group, base_lr in zip(self.param_groups, self.base_lrs):
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        'AdaBound does not support sparse gradients'
+                    )
+
+                amsbound = group['amsbound']
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+                    if amsbound:
+                        state['max_exp_avg_sq'] = torch.zeros_like(p)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsbound:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                if group['weight_decay'] != 0:
+                    grad = grad.add(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                if amsbound:
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = (
+                    group['lr']
+                    * math.sqrt(bias_correction2)
+                    / bias_correction1
+                )
+
+                final_lr = group['final_lr'] * group['lr'] / base_lr
+                lower_bound = final_lr * (
+                    1 - 1 / (group['gamma'] * state['step'] + 1)
+                )
+                upper_bound = final_lr * (
+                    1 + 1 / (group['gamma'] * state['step'])
+                )
+                step_size = torch.full_like(denom, step_size)
+                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
+                    exp_avg
+                )
+
+                p.data.add_(-step_size)
+
+        return loss
+
+
 class AdaBoundW(Optimizer):
     """
     Reference : https://github.com/Luolc/AdaBound
@@ -50,6 +193,11 @@ def __init__(
         :param weight_decay: float. weight decay (L2 penalty)
         :param amsbound: bool. whether to use the AMSBound variant
         """
+        self.lr = lr
+        self.betas = betas
+        self.eps = eps
+        self.weight_decay = weight_decay
+
         defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             betas=betas,
@@ -63,6 +211,18 @@ def __init__(
 
         self.base_lrs = [group['lr'] for group in self.param_groups]
 
+    def check_valid_parameters(self):
+        if 0.0 > self.lr:
+            raise ValueError(f'Invalid learning rate : {self.lr}')
+        if 0.0 > self.eps:
+            raise ValueError(f'Invalid eps : {self.eps}')
+        if 0.0 > self.weight_decay:
+            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
+        if not 0.0 <= self.betas[0] < 1.0:
+            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
+        if not 0.0 <= self.betas[1] < 1.0:
+            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
+
     def __setstate__(self, state: STATE):
         super().__setstate__(state)
         for group in self.param_groups:
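A note on the clipping inside AdaBound.step above: the per-parameter step size is divided by denom and then clamped to [lower_bound, upper_bound], two bounds that both converge to final_lr as state['step'] grows (final_lr is additionally rescaled by group['lr'] / base_lr, so the bounds track per-group learning-rate changes). The sketch below is not part of the commit; it simply evaluates the bound formulas as written in the diff, using the constructor defaults (final_lr=0.1, gamma=1e-3), with illustrative step counts:

# Minimal sketch of how AdaBound's dynamic bounds tighten around final_lr;
# formulas copied from step() above, values below are constructor defaults.
final_lr, gamma = 0.1, 1e-3

for step in (1, 100, 1_000, 10_000, 100_000):
    lower_bound = final_lr * (1 - 1 / (gamma * step + 1))
    upper_bound = final_lr * (1 + 1 / (gamma * step))
    print(f'step={step:>6}  lower={lower_bound:.5f}  upper={upper_bound:.5f}')

# Early in training the interval is wide, so the update behaves like Adam;
# as step grows both bounds approach final_lr, so the update gradually
# approaches SGD with learning rate final_lr.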

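Both classes gain a check_valid_parameters method in this commit, but neither calls it from __init__, so the checks only run when invoked explicitly. A minimal usage sketch under that assumption; the import path follows the docstring example, and the model and hyper-parameters are illustrative:

import torch

from pytorch_optimizer import AdaBound  # import path taken from the docstring example above

model = torch.nn.Linear(10, 2)  # illustrative model
optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=1e-4)
optimizer.check_valid_parameters()  # raises ValueError if lr, eps, weight_decay, or betas are out of range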