Skip to content

Commit 337bb0e

Browse files
committed
feature: implement AdamMini optimizer
1 parent 2dfa980 commit 337bb0e

File tree

1 file changed

+353
-0
lines changed

1 file changed

+353
-0
lines changed
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
import math
2+
from typing import Optional
3+
4+
import torch
5+
from torch import distributed as dist
6+
from torch import nn
7+
from torch.optim.optimizer import Optimizer
8+
9+
from pytorch_optimizer.base.exception import NoSparseGradientError
10+
from pytorch_optimizer.base.optimizer import BaseOptimizer
11+
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS
12+
13+
14+
class AdamMini(Optimizer, BaseOptimizer):
15+
r"""Use Fewer Learning Rates To Gain More.
16+
17+
:param model: nn.Module. model instance.
18+
:param model_sharding: bool. set to True if you are using model parallelism with more than 1 GPU, including FSDP
19+
and zero_1, 2, 3 in Deepspeed. Set to False if otherwise.
20+
:param lr: float. learning rate.
21+
:param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
22+
:param weight_decay: float. weight decay (L2 penalty).
23+
:param num_embeds: int. number of embedding dimensions. could be unspecified if you are training non-transformer
24+
models.
25+
:param num_heads: int. number of attention heads. could be unspecified if you are training non-transformer models.
26+
:param num_query_groups: Optional[int]. number of query groups in Group Query Attention (GQA). if not specified, it
27+
will be equal to num_heads. could be unspecified if you are training non-transformer models.
28+
:param eps: float. term added to the denominator to improve numerical stability.
29+
"""
30+
31+
def __init__(
32+
self,
33+
model: nn.Module,
34+
lr: float = 1.0,
35+
betas: BETAS = (0.9, 0.999),
36+
weight_decay: float = 0.1,
37+
model_sharding: bool = False,
38+
num_embeds: int = 2048,
39+
num_heads: int = 32,
40+
num_query_groups: Optional[int] = None,
41+
eps: float = 1e-8,
42+
):
43+
self.validate_learning_rate(lr)
44+
self.validate_betas(betas)
45+
self.validate_non_negative(weight_decay, 'weight_decay')
46+
self.validate_non_negative(num_embeds, 'num_embeds')
47+
self.validate_non_negative(num_heads, 'num_heads')
48+
self.validate_non_negative(eps, 'eps')
49+
50+
self.num_query_groups: int = num_query_groups if num_query_groups is not None else num_embeds
51+
self.validate_mod(num_embeds, self.num_query_groups)
52+
53+
self.world_size: int = torch.cuda.device_count()
54+
55+
self.model = model
56+
self.model_sharding = model_sharding
57+
self.num_embeds = num_embeds
58+
self.num_heads = num_heads
59+
60+
groups = self.get_optimizer_groups(weight_decay)
61+
62+
defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'eps': eps}
63+
super().__init__(groups, defaults)
64+
65+
def __str__(self) -> str:
66+
return 'AdamMini'
67+
68+
def get_optimizer_groups(self, weight_decay: float):
69+
groups = []
70+
for name, param in self.model.named_parameters():
71+
if not param.requires_grad:
72+
continue
73+
74+
group = {
75+
'name': name,
76+
'params': param,
77+
'weight_decay': 0.0 if ('norm' in name or 'ln_f' in name) else weight_decay,
78+
}
79+
80+
if (
81+
'self_attn.k_proj.weight' in name
82+
or 'self_attn.q_proj.weight' in name
83+
or 'attn.wq.weight' in name
84+
or 'attn.wk.weight' in name
85+
):
86+
group['parameter_per_head'] = self.num_embeds * self.num_embeds // self.num_heads
87+
88+
if 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
89+
group['n_head'] = self.num_heads
90+
group['q_per_kv'] = self.num_embeds // self.num_query_groups
91+
92+
groups.append(group)
93+
94+
return groups
95+
96+
@torch.no_grad()
97+
def reset(self):
98+
for group in self.param_groups:
99+
group['step'] = 0
100+
for p in group['params']:
101+
state = self.state[p]
102+
103+
state['m'] = torch.zeros_like(p, dtype=torch.float32)
104+
state['v'] = torch.zeros_like(p, dtype=torch.float32)
105+
106+
@staticmethod
107+
def step_embed(
108+
p,
109+
grad,
110+
state,
111+
lr: float,
112+
beta1: float,
113+
beta2: float,
114+
bias_correction1: float,
115+
bias_correction2_sq: float,
116+
eps: float,
117+
) -> None:
118+
if len(state) == 0:
119+
state['m'] = torch.zeros_like(p, dtype=torch.float32)
120+
state['v'] = torch.zeros_like(p, dtype=torch.float32)
121+
122+
m, v = state['m'], state['v']
123+
124+
m.lerp_(grad, weight=1.0 - beta1)
125+
v.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
126+
127+
h = (v.sqrt() / bias_correction2_sq).add_(eps)
128+
129+
p.addcdiv_(m, h, value=-lr / bias_correction1)
130+
131+
@staticmethod
132+
def step_attn_proj(
133+
p,
134+
grad,
135+
state,
136+
parameter_per_head: int,
137+
lr: float,
138+
beta1: float,
139+
beta2: float,
140+
bias_correction1: float,
141+
bias_correction2_sq: float,
142+
eps: float,
143+
) -> None:
144+
if len(state) == 0:
145+
state['m'] = torch.zeros_like(p, dtype=torch.float32).view(-1, parameter_per_head)
146+
state['head'] = state['m'].shape[0]
147+
state['v_mean'] = torch.zeros(state['head'], device=state['m'].device)
148+
149+
m, v = state['m'], state['v_mean']
150+
151+
head: int = state['head']
152+
grad = grad.view(head, parameter_per_head)
153+
154+
m.lerp_(grad, weight=1.0 - beta1)
155+
156+
tmp_lr = torch.mean(grad * grad, dim=1).to(m.device)
157+
v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)
158+
159+
h = (v.sqrt() / bias_correction2_sq).add_(eps)
160+
161+
update = (1 / (h * bias_correction1)).view(head, 1).mul_(m)
162+
163+
if p.dim() > 1:
164+
d0, d1 = p.size()
165+
update = update.view(d0, d1)
166+
else:
167+
update = update.view(-1)
168+
169+
p.add_(update, alpha=-lr)
170+
171+
@staticmethod
172+
def step_attn(
173+
p,
174+
grad,
175+
state,
176+
num_heads: int,
177+
q_per_kv: int,
178+
lr: float,
179+
beta1: float,
180+
beta2: float,
181+
bias_correction1: float,
182+
bias_correction2_sq: float,
183+
eps: float,
184+
) -> None:
185+
if len(state) == 0:
186+
state['m'] = torch.zeros_like(p, dtype=torch.float32).view(num_heads, q_per_kv + 2, -1)
187+
state['v_mean'] = torch.zeros(num_heads, q_per_kv + 2, device=state['m'].device)
188+
189+
m, v = state['m'], state['v_mean']
190+
191+
grad = grad.view(num_heads, q_per_kv + 2, -1)
192+
193+
m.lerp_(grad, weight=1.0 - beta1)
194+
195+
tmp_lr = torch.mean(grad * grad, dim=2).to(m.device)
196+
v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)
197+
198+
h = (v.sqrt() / bias_correction2_sq).add_(eps)
199+
200+
update = (1 / (h * bias_correction1)).view(num_heads, q_per_kv + 2, -1).mul_(m)
201+
202+
if p.dim() > 1:
203+
d0, d1 = p.size()
204+
update = update.view(d0, d1)
205+
else:
206+
update = update.view(-1)
207+
208+
p.add_(update, alpha=-lr)
209+
210+
def step_lefts(
211+
self,
212+
p,
213+
grad,
214+
state,
215+
lr: float,
216+
beta1: float,
217+
beta2: float,
218+
bias_correction1: float,
219+
bias_correction2_sq: float,
220+
eps: float,
221+
) -> None: # pragma: no cover
222+
if len(state) == 0:
223+
dim = torch.tensor(p.numel(), device=p.device, dtype=torch.float32)
224+
225+
reduced: bool = False
226+
if self.model_sharding and self.world_size > 1:
227+
tensor_list = [torch.zeros_like(dim) for _ in range(self.world_size)]
228+
dist.all_gather(tensor_list, dim)
229+
230+
s, dim = 0, 0
231+
for d in tensor_list:
232+
if d > 0:
233+
s += 1
234+
dim += d
235+
236+
if s >= 2:
237+
reduced = True
238+
239+
state['m'] = torch.zeros_like(p, dtype=torch.float32)
240+
state['v_mean'] = torch.tensor(0.0, device=state['m'].device)
241+
state['dimension'] = dim
242+
state['reduced'] = reduced
243+
244+
tmp_lr = torch.sum(grad * grad)
245+
246+
if state['reduced']:
247+
dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)
248+
249+
tmp_lr.div_(state['dim'])
250+
251+
m, v = state['m'], state['v_mean']
252+
253+
m.lerp_(grad, weight=1.0 - beta1)
254+
v.mul_(beta2).add_(tmp_lr, value=1.0 - beta2)
255+
256+
h = (v.sqrt() / bias_correction2_sq).add_(eps)
257+
258+
update = 1 / (bias_correction1 * h).mul_(m)
259+
260+
p.add_(update, alpha=-lr)
261+
262+
@torch.no_grad()
263+
def step(self, closure: CLOSURE = None) -> LOSS:
264+
loss: LOSS = None
265+
if closure is not None:
266+
with torch.enable_grad():
267+
loss = closure()
268+
269+
for group in self.param_groups:
270+
if 'step' in group:
271+
group['step'] += 1
272+
else:
273+
group['step'] = 1
274+
275+
name = group['name']
276+
277+
beta1, beta2 = group['betas']
278+
279+
bias_correction1: float = 1.0 - beta1 ** group['step']
280+
bias_correction2: float = 1.0 - beta2 ** group['step']
281+
bias_correction2_sq: float = math.sqrt(bias_correction2)
282+
283+
for p in group['params']:
284+
if p.grad is None:
285+
continue
286+
287+
grad = p.grad
288+
if grad.is_sparse:
289+
raise NoSparseGradientError(str(self))
290+
291+
grad = grad.to(torch.float32)
292+
293+
state = self.state[p]
294+
295+
self.apply_weight_decay(
296+
p=p,
297+
grad=grad,
298+
lr=group['lr'],
299+
weight_decay=group['weight_decay'],
300+
weight_decouple=True,
301+
fixed_decay=False,
302+
)
303+
304+
if 'embed_tokens' in name or 'wte' in name or 'lm_head' in name:
305+
self.step_embed(
306+
p, grad, state, group['lr'], beta1, beta2, bias_correction1, bias_correction2_sq, group['eps']
307+
)
308+
elif (
309+
'self_attn.k_proj.weight' in name
310+
or 'self_attn.q_proj.weight' in name
311+
or 'attn.wq.weight' in name
312+
or 'attn.wk.weight' in name
313+
):
314+
self.step_attn_proj(
315+
p,
316+
grad,
317+
state,
318+
group['parameter_per_head'],
319+
group['lr'],
320+
beta1,
321+
beta2,
322+
bias_correction1,
323+
bias_correction2_sq,
324+
group['eps'],
325+
)
326+
elif 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
327+
self.step_attn(
328+
p,
329+
grad,
330+
state,
331+
group['n_head'],
332+
group['q_per_kv'],
333+
group['lr'],
334+
beta1,
335+
beta2,
336+
bias_correction1,
337+
bias_correction2_sq,
338+
group['eps'],
339+
)
340+
else:
341+
self.step_lefts(
342+
p,
343+
grad,
344+
state,
345+
group['lr'],
346+
beta1,
347+
beta2,
348+
bias_correction1,
349+
bias_correction2_sq,
350+
group['eps'],
351+
)
352+
353+
return loss

0 commit comments

Comments
 (0)