Skip to content

Commit b672031

Browse files
committed
feature: implement AdaGC optimizer
1 parent d0dab0f commit b672031

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import math
2+
3+
import torch
4+
5+
from pytorch_optimizer.base.exception import NoSparseGradientError
6+
from pytorch_optimizer.base.optimizer import BaseOptimizer
7+
from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
8+
from pytorch_optimizer.optimizer.utils import get_global_gradient_norm
9+
10+
11+
class AdaGC(BaseOptimizer):
    r"""Improving Training Stability for Large Language Model Pretraining.

    AdaGC combines Adam-style moment estimation with adaptive gradient clipping:
    during warmup the gradient is clipped against an absolute threshold
    (``lambda_abs`` over the global gradient norm); afterwards it is clipped
    against a relative threshold derived from an EMA (``gamma``) of recent
    clipped-gradient norms.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param beta: float. smoothing coefficient for EMA.
    :param lambda_abs: float. absolute clipping threshold to prevent unstable updates from gradient explosions.
    :param lambda_rel: float. relative clipping threshold to prevent unstable updates from gradient explosions.
    :param warmup_steps: int. warmup steps.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        beta: float = 0.98,
        lambda_abs: float = 1.0,
        lambda_rel: float = 1.05,
        warmup_steps: int = 100,
        weight_decay: float = 1e-1,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(beta, 'beta', 0.0, 1.0, '[)')
        self.validate_positive(lambda_abs, 'lambda_abs')
        self.validate_positive(lambda_rel, 'lambda_rel')
        self.validate_non_negative(warmup_steps, 'warmup_steps')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'beta': beta,
            'lambda_abs': lambda_abs,
            'lambda_rel': lambda_rel,
            'warmup_steps': warmup_steps,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaGC'

    @torch.no_grad()
    def reset(self):
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. a closure that re-evaluates the model and returns the loss.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)
                    # fix: allocate with zeros instead of empty so `gamma` never holds
                    # uninitialized memory (it is overwritten on step 1, see below).
                    state['gamma'] = torch.zeros((1,), device=grad.device, dtype=grad.dtype)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                gamma = state['gamma']

                # fix: also take the absolute-clipping branch on the very first step so
                # that `gamma` is always initialized before the relative branch reads it.
                # The original guard `step < warmup_steps` skipped initialization entirely
                # when warmup_steps was 0 or 1, leaving `gamma` as uninitialized memory.
                # Behavior is unchanged for warmup_steps >= 2.
                if group['step'] == 1 or group['step'] < group['warmup_steps']:
                    # NOTE(review): the global norm is recomputed per parameter here; with
                    # weight_decouple=False earlier params' grads are already decayed, so
                    # hoisting it out of the loop would change results — left in place.
                    grad_norm = get_global_gradient_norm(self.param_groups).add_(group['eps'])

                    # clip against the absolute threshold (h_t <= 1 keeps grads unscaled
                    # when they are already small enough).
                    h_t = min(group['lambda_abs'] / grad_norm, 1.0)
                    g_hat = grad.mul(h_t)

                    g_hat_norm = g_hat.norm()

                    # bootstrap gamma on the first step, then track the running minimum
                    # of clipped-gradient norms over warmup.
                    gamma.copy_(g_hat_norm if group['step'] == 1 else min(gamma, g_hat_norm))
                else:
                    # fix: add eps to the denominator (consistent with the warmup branch)
                    # so a zero gradient does not produce inf/nan clipping factors.
                    h_t = min(group['lambda_rel'] * gamma / grad.norm().add_(group['eps']), 1.0)
                    g_hat = grad.mul(h_t)

                    # EMA of the clipped-gradient norm drives the relative threshold.
                    gamma.mul_(group['beta']).add_(g_hat.norm(), alpha=1.0 - group['beta'])

                # standard Adam moments on the clipped gradient.
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(g_hat, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_hat, g_hat, value=1.0 - beta2)

                update = (exp_avg / bias_correction1) / exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.add_(update, alpha=-group['lr'])

        return loss

0 commit comments

Comments
 (0)