
Commit 4609df4

feature: GrokFast optimizer
1 parent 7a39a1e commit 4609df4

1 file changed: +212 −0 lines changed
@@ -0,0 +1,212 @@
import math
from collections import deque
from typing import Dict, Literal, Optional

import torch
from torch import nn
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS

FILTER_TYPE = Literal['mean', 'sum']


@torch.no_grad()
def gradfilter_ma(
    model: nn.Module,
    grads: Optional[Dict[str, deque]] = None,
    window_size: int = 100,
    lamb: float = 5.0,
    filter_type: FILTER_TYPE = 'mean',
    warmup: bool = True,
) -> Dict[str, deque]:
    r"""Grokfast-MA.

    :param model: nn.Module. model that contains every trainable parameter.
    :param grads: Optional[Dict[str, deque]]. running memory (queue for the windowed moving average). initialize by
        passing None. feed the returned value back in on every subsequent call.
    :param window_size: int. width of the filter window. the additional memory requirement grows linearly with the
        window size.
    :param lamb: float. amplifying factor hyperparameter of the filter.
    :param filter_type: FILTER_TYPE. aggregation method for the running queue.
    :param warmup: bool. if True, the filter is not applied until the queue is filled.
    """
    if grads is None:
        grads = {n: deque(maxlen=window_size) for n, p in model.named_parameters() if p.requires_grad}

    for n, p in model.named_parameters():
        if p.requires_grad:
            grads[n].append(p.grad)

            if not warmup or len(grads[n]) == window_size:
                if filter_type == 'mean':
                    avg = sum(grads[n]) / len(grads[n])
                elif filter_type == 'sum':
                    avg = sum(grads[n])
                else:
                    raise ValueError(f'Unrecognized filter_type {filter_type}')

                p.grad.add_(avg, alpha=lamb)

    return grads

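For context, a minimal training-loop sketch of how gradfilter_ma is meant to be called. It is not part of the committed file; `model`, `criterion`, `loader`, and `base_optimizer` are placeholder names. The filter runs between backward() and the optimizer step, and the returned dict of queues must be fed back in on every call.

# Illustrative only: placeholder model / criterion / loader / base_optimizer.
grads = None  # running memory, created on the first call
for inputs, targets in loader:
    base_optimizer.zero_grad()
    criterion(model(inputs), targets).backward()
    # amplify the windowed moving average of past gradients in-place
    grads = gradfilter_ma(model, grads=grads, window_size=100, lamb=5.0)
    base_optimizer.step()
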
@torch.no_grad()
def gradfilter_ema(
    model: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    r"""Grokfast.

    :param model: nn.Module. model that contains every trainable parameter.
    :param grads: Optional[Dict[str, torch.Tensor]]. running memory (EMA). initialize by passing None. feed the
        returned value back in on every subsequent call.
    :param alpha: float. momentum hyperparameter of the EMA.
    :param lamb: float. amplifying factor hyperparameter of the filter.
    """
    if grads is None:
        grads = {n: p.grad for n, p in model.named_parameters() if p.requires_grad}

    for n, p in model.named_parameters():
        if p.requires_grad:
            grads[n].mul_(alpha).add_(p.grad, alpha=1.0 - alpha)
            p.grad.add_(grads[n], alpha=lamb)

    return grads

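The EMA variant is used the same way, but the running memory is a single tensor per parameter rather than a window of raw gradients. A sketch under the same placeholder names as above:

grads = None
for inputs, targets in loader:
    base_optimizer.zero_grad()
    criterion(model(inputs), targets).backward()
    # low-pass filter the gradients and add the amplified slow component back in-place
    grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)
    base_optimizer.step()
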

class GrokFastAdamW(Optimizer, BaseOptimizer):
    r"""Accelerated Grokking by Amplifying Slow Gradients with AdamW.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param grokfast: bool. whether to use grokfast.
    :param grokfast_alpha: float. momentum hyperparameter of the EMA.
    :param grokfast_lamb: float. amplifying factor hyperparameter of the filter.
    :param grokfast_after_step: int. warmup step for grokfast.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param normalize_lr: bool. whether to divide the learning rate by 1 + grokfast_lamb when grokfast is enabled.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.9, 0.99),
        grokfast: bool = True,
        grokfast_alpha: float = 0.98,
        grokfast_lamb: float = 2.0,
        grokfast_after_step: int = 0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        normalize_lr: bool = True,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(grokfast_alpha, 'grokfast_alpha', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        # keep the effective update magnitude comparable when the filter amplifies gradients
        if grokfast and normalize_lr:
            lr /= 1.0 + grokfast_lamb

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'grokfast': grokfast,
            'grokfast_alpha': grokfast_alpha,
            'grokfast_lamb': grokfast_lamb,
            'grokfast_after_step': grokfast_after_step,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'GrokFastAdamW'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            should_grokfast: bool = (
                group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if should_grokfast:
                        state['grok_exp_avg'] = grad.clone()

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if should_grokfast:
                    # EMA of the gradient (the slow component), amplified and added back in-place
                    grok_exp_avg = state['grok_exp_avg']
                    grok_exp_avg.lerp_(grad, weight=1.0 - group['grokfast_alpha'])

                    grad.add_(grok_exp_avg, alpha=group['grokfast_lamb'])

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).clamp_(min=group['eps'])

                update = exp_avg.div(bias_correction1).div_(de_nom)

                p.add_(update, alpha=-group['lr'])

        return loss
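
A minimal usage sketch for the optimizer itself (not part of this commit; `model`, `criterion`, and `loader` are placeholders, and the warmup and weight-decay values shown are illustrative). With grokfast=True the EMA filtering happens inside step(), so no external gradfilter_* call is needed; note that with normalize_lr=True the learning rate passed in is divided by 1 + grokfast_lamb.

# Illustrative only: placeholder model / criterion / loader.
optimizer = GrokFastAdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    grokfast=True,
    grokfast_alpha=0.98,
    grokfast_lamb=2.0,
    grokfast_after_step=100,  # illustrative: skip the filter for the first 100 steps
    weight_decay=1e-2,        # illustrative value
)

for inputs, targets in loader:
    optimizer.zero_grad()
    criterion(model(inputs), targets).backward()
    optimizer.step()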

0 commit comments

Comments
 (0)