Skip to content

Commit 03f423e

Browse files
committed
feature: implement SPAM optimizer
1 parent ead4aae commit 03f423e

File tree

1 file changed

+275
-0
lines changed

1 file changed

+275
-0
lines changed
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import math
2+
3+
import torch
4+
5+
from pytorch_optimizer.base.exception import NoSparseGradientError
6+
from pytorch_optimizer.base.optimizer import BaseOptimizer
7+
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
8+
9+
10+
class CosineDecay:
11+
r"""Applies cosine decay to a parameter (death_rate), using PyTorch's built-in `CosineAnnealingLR`.
12+
13+
:param death_rate: float. initial value to be decayed.
14+
:param t_max: int. maximum number of iterations for the decay.
15+
:param eta_min: Optional[float]. minimum value of the parameter after decay. defaults to 0.
16+
:param last_epoch: Optional[int]. the index of the last epoch. Defaults to -1.
17+
"""
18+
19+
def __init__(self, death_rate: float, t_max: int, eta_min: float = 0.0, last_epoch: int = -1):
20+
self.sgd = torch.optim.SGD(
21+
torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(1))]),
22+
lr=death_rate,
23+
)
24+
self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(self.sgd, t_max + 1, eta_min, last_epoch)
25+
self.T_max = t_max
26+
self.eta_min = eta_min
27+
28+
def step(self, current_step: int) -> None:
29+
r"""One step of the cosine decay scheduler.
30+
31+
:param current_step: int. Current step index.
32+
"""
33+
self.cosine_stepper.step(current_step)
34+
35+
def get_death_rate(self, current_step: int) -> float:
36+
r"""Get the updated rate (death_rate) at the given step.
37+
38+
:param current_step: int. Current step index.
39+
"""
40+
if current_step >= self.T_max:
41+
return self.eta_min
42+
43+
self.step(current_step)
44+
45+
return self.sgd.param_groups[0]['lr']
46+
47+
48+
class SPAM(BaseOptimizer):
49+
r"""Spike-Aware Adam with Momentum Reset for Stable LLM Training.
50+
51+
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
52+
:param lr: float. learning rate.
53+
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
54+
:param density: float. density parameter. only used for 2d parameters (e.g. Linear).
55+
:param weight_decay: float. weight decay (L2 penalty).
56+
:param warmup_epoch: int: number of epochs to warm up. defaults to 50.
57+
:param threshold: int. threshold for gradient masking. defaults to 5000.
58+
:param grad_accu_steps: int. gradient accumulation steps before threshold-based masking applies. defaults to 20.
59+
:param update_proj_gap: int. update projection gap.
60+
:param eps: float. term added to the denominator to improve numerical stability.
61+
"""
62+
63+
def __init__(
64+
self,
65+
params: PARAMETERS,
66+
lr: float = 1e-3,
67+
betas: BETAS = (0.9, 0.999),
68+
density: float = 1.0,
69+
weight_decay: float = 0.0,
70+
warmup_epoch: int = 150,
71+
threshold: int = 5000,
72+
grad_accu_steps: int = 20,
73+
update_proj_gap: int = 500,
74+
eps: float = 1e-6,
75+
**kwargs,
76+
):
77+
self.validate_learning_rate(lr)
78+
self.validate_betas(betas)
79+
self.validate_non_negative(weight_decay, 'weight_decay')
80+
self.validate_non_negative(warmup_epoch, 'warmup_epoch')
81+
self.validate_non_negative(density, 'density')
82+
self.validate_non_negative(threshold, 'threshold')
83+
self.validate_non_negative(grad_accu_steps, 'grad_accu_steps')
84+
self.validate_non_negative(update_proj_gap, 'update_proj_gap')
85+
self.validate_non_negative(eps, 'eps')
86+
87+
self.density = density
88+
self.warmup_epoch = warmup_epoch
89+
self.threshold = threshold
90+
self.grad_accu_steps = grad_accu_steps
91+
self.update_proj_gap = update_proj_gap
92+
self.warmup = CosineDecay(0.99, warmup_epoch)
93+
94+
defaults: DEFAULTS = {
95+
'lr': lr,
96+
'betas': betas,
97+
'weight_decay': weight_decay,
98+
'eps': eps,
99+
**kwargs,
100+
}
101+
super().__init__(params, defaults)
102+
103+
self.init_masks()
104+
105+
self.state['total_step'] = 0
106+
self.state['current_step'] = warmup_epoch + 1
107+
108+
@staticmethod
109+
def initialize_random_rank_boolean_tensor(m: int, n: int, density: float) -> torch.Tensor:
110+
r"""Create an (m x n) boolean tensor with `density` fraction of True entries.
111+
112+
:param m: int. number of rows.
113+
:param n: int. number of columns.
114+
:param density: float. fraction of True entries. 1.0 means all True.
115+
"""
116+
total_elements: int = m * n
117+
non_zero_count: int = int(density * total_elements)
118+
119+
tensor = torch.zeros((m, n), dtype=torch.bool)
120+
121+
if non_zero_count == 0:
122+
return tensor
123+
124+
indices = torch.randperm(total_elements)[:non_zero_count]
125+
rows, cols = indices // n, indices % n
126+
tensor[rows, cols] = True
127+
128+
return tensor
129+
130+
def update_mask_random(self, density: float, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
131+
r"""Update a random mask.
132+
133+
Create a new random mask with the same density, compute overlap ratio with old_mask, and update the EMA for
134+
the overlap region.
135+
136+
:param density: float. fraction of elements to keep.
137+
:param p: torch.Tensor. parameter to which the mask is applied.
138+
:param old_mask: torch.Tensor. previous binary mask.
139+
"""
140+
new_mask: torch.Tensor = torch.rand_like(p) < density
141+
142+
exp_avg = torch.zeros_like(p[new_mask])
143+
exp_avg_sq = torch.zeros_like(p[new_mask])
144+
145+
intersection_mask = new_mask & old_mask
146+
new_intersection_indices = intersection_mask[new_mask]
147+
old_intersection_indices = intersection_mask[old_mask]
148+
149+
state = self.state[p]
150+
exp_avg[new_intersection_indices] = state['exp_avg'][old_intersection_indices]
151+
exp_avg_sq[new_intersection_indices] = state['exp_avg_sq'][old_intersection_indices]
152+
153+
state['exp_avg'] = exp_avg
154+
state['exp_avg_sq'] = exp_avg_sq
155+
156+
return new_mask
157+
158+
def update_masks(self) -> None:
159+
r"""Update masks in each parameter group that has 'density'.
160+
161+
The new mask is selected randomly, and the overlap ratio with the old mask is printed.
162+
"""
163+
for group in self.param_groups:
164+
for p in group['params']:
165+
state = self.state[p]
166+
if 'mask' in state:
167+
new_mask = self.update_mask_random(self.density, p, state['mask'])
168+
state['mask'] = new_mask
169+
p.mask = new_mask
170+
171+
def init_masks(self) -> None:
172+
r"""Initialize random masks for each parameter group that has 'density'."""
173+
for group in self.param_groups:
174+
for p in group['params']:
175+
state = self.state[p]
176+
if p.dim() == 2 and 'mask' not in state:
177+
state['mask'] = self.initialize_random_rank_boolean_tensor(
178+
p.shape[0],
179+
p.shape[1],
180+
density=self.density,
181+
).to(p.device)
182+
183+
def __str__(self) -> str:
184+
return 'SPAM'
185+
186+
@torch.no_grad()
187+
def reset(self):
188+
for group in self.param_groups:
189+
group['step'] = 0
190+
for p in group['params']:
191+
state = self.state[p]
192+
193+
state['exp_avg'] = torch.zeros_like(p)
194+
state['exp_avg_sq'] = torch.zeros_like(p)
195+
196+
@torch.no_grad()
197+
def step(self, closure: CLOSURE = None) -> LOSS:
198+
loss: LOSS = None
199+
if closure is not None:
200+
with torch.enable_grad():
201+
loss = closure()
202+
203+
scale_factor: float = 1.0 - self.warmup.get_death_rate(self.state['current_step'])
204+
205+
for group in self.param_groups:
206+
if 'step' not in group:
207+
group['step'] = 1
208+
else:
209+
group['step'] += 1
210+
211+
beta1, beta2 = group['betas']
212+
213+
bias_correction1: float = self.debias(beta1, group['step'])
214+
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
215+
216+
step_size: float = group['lr'] * bias_correction2_sq / bias_correction1
217+
218+
for p in group['params']:
219+
if p.grad is None:
220+
continue
221+
222+
grad = p.grad
223+
if grad.is_sparse:
224+
raise NoSparseGradientError(str(self))
225+
226+
state = self.state[p]
227+
228+
if 'mask' in state:
229+
grad = grad[state['mask']]
230+
231+
if len(state) == 0:
232+
state['exp_avg'] = torch.zeros_like(grad)
233+
state['exp_avg_sq'] = torch.zeros_like(grad)
234+
235+
if (self.state['total_step'] + 1) % self.update_proj_gap == 0:
236+
state['exp_avg'] = torch.zeros_like(grad)
237+
state['exp_avg_sq'] = torch.zeros_like(grad)
238+
239+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
240+
241+
if self.threshold != 0:
242+
current_step: int = self.state['total_step'] + 1
243+
if current_step >= self.grad_accu_steps and (
244+
self.update_proj_gap == 0 or current_step % self.update_proj_gap >= self.grad_accu_steps
245+
):
246+
mask = grad.pow(2) > (self.threshold * exp_avg_sq)
247+
grad[mask].sign_().mul_(torch.sqrt(exp_avg_sq[mask] * self.threshold))
248+
249+
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
250+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
251+
252+
de_nom = exp_avg_sq.sqrt().add_(group['eps'])
253+
254+
if 'mask' in state:
255+
grad_full = torch.zeros_like(p.grad)
256+
grad_full[state['mask']] = exp_avg / de_nom
257+
p.add_(grad_full, alpha=-step_size * scale_factor)
258+
else:
259+
p.addcdiv_(exp_avg, de_nom, value=-step_size * scale_factor)
260+
261+
if group['weight_decay'] > 0:
262+
if 'mask' in state:
263+
p[state['mask']].add_(p[state['mask']], alpha=-group['lr'] * group['weight_decay'])
264+
else:
265+
p.add_(p, alpha=-group['lr'] * group['weight_decay'])
266+
267+
self.state['total_step'] += 1
268+
self.state['current_step'] += 1
269+
270+
if (self.state['total_step'] != 0) and (self.state['total_step'] + 1) % self.update_proj_gap == 0:
271+
self.update_masks()
272+
self.state['current_step'] = 0
273+
self.warmup = CosineDecay(0.99, self.warmup_epoch)
274+
275+
return loss

0 commit comments

Comments
 (0)