Commit 173c406

feature: combine AdaBoundW into AdaBound optimizer
1 parent eec24fd commit 173c406

File tree

1 file changed: +13 -143 lines changed

pytorch_optimizer/adabound.py

Lines changed: 13 additions & 143 deletions
@@ -38,6 +38,8 @@ def __init__(
         gamma: float = 1e-3,
         eps: float = 1e-8,
         weight_decay: float = 0.0,
+        weight_decouple: bool = True,
+        fixed_decay: bool = False,
         amsbound: bool = False,
     ):
         """AdaBound optimizer
@@ -46,6 +48,8 @@ def __init__(
         :param final_lr: float. final learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param gamma: float. convergence speed of the bound functions
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
+        :param fixed_decay: bool.
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param amsbound: bool. whether to use the AMSBound variant
@@ -54,6 +58,8 @@ def __init__(
         self.betas = betas
         self.eps = eps
         self.weight_decay = weight_decay
+        self.weight_decouple = weight_decouple
+        self.fixed_decay = fixed_decay
 
         defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
@@ -119,150 +125,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state['step'] += 1
 
-                if group['weight_decay'] != 0:
-                    grad = grad.add(group['weight_decay'], p.data)
-
-                # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                if amsbound:
-                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                if self.weight_decouple:
+                    if not self.fixed_decay:
+                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    else:
+                        p.data.mul_(1.0 - group['weight_decay'])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
-
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = (
-                    group['lr']
-                    * math.sqrt(bias_correction2)
-                    / bias_correction1
-                )
-
-                final_lr = group['final_lr'] * group['lr'] / base_lr
-                lower_bound = final_lr * (
-                    1 - 1 / (group['gamma'] * state['step'] + 1)
-                )
-                upper_bound = final_lr * (
-                    1 + 1 / (group['gamma'] * state['step'])
-                )
-                step_size = torch.full_like(denom, step_size)
-                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
-                    exp_avg
-                )
-
-                p.data.add_(-step_size)
-
-        return loss
-
-
-class AdaBoundW(Optimizer):
-    """
-    Reference : https://github.com/Luolc/AdaBound
-    Example :
-        from pytorch_optimizer import AdaBoundW
-        ...
-        model = YourModel()
-        optimizer = AdaBoundW(model.parameters())
-        ...
-        for input, output in data:
-            optimizer.zero_grad()
-            loss = loss_function(output, model(input))
-            loss.backward()
-            optimizer.step()
-    """
-
-    def __init__(
-        self,
-        params: PARAMS,
-        lr: float = 1e-3,
-        betas: BETAS = (0.9, 0.999),
-        final_lr: float = 0.1,
-        gamma: float = 1e-3,
-        eps: float = 1e-8,
-        weight_decay: float = 0.0,
-        amsbound: bool = False,
-    ):
-        """AdaBound optimizer with decoupled weight decay
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate
-        :param final_lr: float. final learning rate
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param gamma: float. convergence speed of the bound functions
-        :param eps: float. term added to the denominator to improve numerical stability
-        :param weight_decay: float. weight decay (L2 penalty)
-        :param amsbound: bool. whether to use the AMSBound variant
-        """
-        self.lr = lr
-        self.betas = betas
-        self.eps = eps
-        self.weight_decay = weight_decay
-
-        defaults: DEFAULT_PARAMETERS = dict(
-            lr=lr,
-            betas=betas,
-            final_lr=final_lr,
-            gamma=gamma,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsbound=amsbound,
-        )
-        super().__init__(params, defaults)
-
-        self.base_lrs = [group['lr'] for group in self.param_groups]
-
-    def check_valid_parameters(self):
-        if 0.0 > self.lr:
-            raise ValueError(f'Invalid learning rate : {self.lr}')
-        if 0.0 > self.eps:
-            raise ValueError(f'Invalid eps : {self.eps}')
-        if 0.0 > self.weight_decay:
-            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
-        if not 0.0 <= self.betas[0] < 1.0:
-            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
-        if not 0.0 <= self.betas[1] < 1.0:
-            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
-
-    def __setstate__(self, state: STATE):
-        super().__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsbound', False)
-
-    def step(self, closure: CLOSURE = None) -> LOSS:
-        loss: LOSS = None
-        if closure is not None:
-            loss = closure()
-
-        for group, base_lr in zip(self.param_groups, self.base_lrs):
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-
-                p.mul_(1 - base_lr * group['weight_decay'])
-
-                grad = p.grad.data
-                if grad.is_sparse:
-                    raise RuntimeError(
-                        'AdaBound does not support sparse gradients'
-                    )
-
-                amsbound = group['amsbound']
-
-                state = self.state[p]
-
-                if len(state) == 0:
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_sq'] = torch.zeros_like(p)
-                    if amsbound:
-                        state['max_exp_avg_sq'] = torch.zeros_like(p)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                if amsbound:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                state['step'] += 1
+                    if group['weight_decay'] != 0:
+                        grad.add_(p.data, alpha=group['weight_decay'])
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
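After this change, the decoupled weight decay that the removed AdaBoundW class provided is selected through the new weight_decouple flag on AdaBound (with fixed_decay additionally dropping the lr scaling of the decay factor). A minimal usage sketch follows, assuming AdaBound is importable from pytorch_optimizer in the same way AdaBoundW was in its docstring example; the model and batch below are placeholders, not part of the commit:

import torch
import torch.nn.functional as F

from pytorch_optimizer import AdaBound

model = torch.nn.Linear(10, 1)  # placeholder model

# weight_decouple=True multiplies the weights by (1 - lr * weight_decay)
# before the update (the former AdaBoundW behaviour); weight_decouple=False
# adds the classic L2 penalty to the gradient instead. fixed_decay=True
# would use (1 - weight_decay) directly, without the lr scaling.
optimizer = AdaBound(
    model.parameters(),
    lr=1e-3,
    final_lr=0.1,
    weight_decay=1e-2,
    weight_decouple=True,
)

x, y = torch.randn(8, 10), torch.randn(8, 1)  # dummy batch
optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()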
0 commit comments
