Skip to content

Commit d11bd29

Browse files
committed
feature: AdamS optimizer
1 parent 62ff084 commit d11bd29

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import math
2+
3+
import torch
4+
from torch.optim.optimizer import Optimizer
5+
6+
from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
7+
from pytorch_optimizer.base.optimizer import BaseOptimizer
8+
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
9+
from pytorch_optimizer.optimizer.gc import centralize_gradient
10+
11+
12+
class AdamS(Optimizer, BaseOptimizer):
    r"""Adam with stable weight decay.

    Reference: "Stable Weight Decay Regularization" (https://arxiv.org/abs/2011.11152).
    The key difference from Adam/AdamW is that the per-parameter denominator is
    normalized by the *global* mean of the bias-corrected second moment, which makes
    the effective weight-decay strength stable across training.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
        NOTE(review): the default (0.1, 0.99) differs from the usual Adam (0.9, 0.999) — kept as-is for
        backward compatibility, but worth confirming against the paper.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param amsgrad: bool. whether to use the AMSGrad variant of this algorithm from the paper.
    :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.1, 0.99),
        weight_decay: float = 1e-4,
        weight_decouple: bool = False,
        amsgrad: bool = False,
        adamd_debias_term: bool = False,
        eps: float = 1e-8,
    ):
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple
        self.amsgrad = amsgrad
        self.adamd_debias_term = adamd_debias_term
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        """Raise a validation error (via BaseOptimizer helpers) on out-of-range hyper-parameters."""
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)
        self.validate_epsilon(self.eps)

    def __str__(self) -> str:
        # NOTE: this was previously decorated with @property, which broke str(optimizer)
        # ("'property' object is not callable"); a plain dunder method restores it.
        return 'AdamS'

    @torch.no_grad()
    def reset(self):
        """Re-initialize per-parameter optimizer state to zeros."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                # keep reset() consistent with the lazy initialization in step()
                if self.amsgrad:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. a closure that re-evaluates the model and returns the loss.
        :raises NoSparseGradientError: if any parameter has a sparse gradient.
        :raises ZeroParameterSizeError: if no parameter has a gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        # First pass: update the first/second moments and accumulate the global
        # (bias-corrected) second-moment statistic used for stable weight decay.
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                param_size += p.numel()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if self.amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                bias_correction2 = 1.0 - beta2 ** state['step']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if self.amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    exp_avg_sq_hat = max_exp_avg_sq
                else:
                    exp_avg_sq_hat = exp_avg_sq

                # .item() keeps the accumulator a Python float, matching its annotation.
                exp_avg_sq_hat_sum += (exp_avg_sq_hat.sum() / bias_correction2).item()

        if param_size == 0:
            raise ZeroParameterSizeError()

        # Global RMS of the bias-corrected second moment. eps guards the division
        # below when every gradient is exactly zero.
        exp_avg_sq_hat_mean = math.sqrt(exp_avg_sq_hat_sum / param_size) + self.eps

        # Second pass: apply weight decay and the parameter update.
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                if group['weight_decay'] > 0.0:
                    if self.weight_decouple:
                        p.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        grad.add_(p, alpha=group['weight_decay'])

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                bias_correction1 = 1.0 - beta1 ** state['step']
                bias_correction2 = 1.0 - beta2 ** state['step']

                exp_avg_sq_hat = state['max_exp_avg_sq'] if self.amsgrad else exp_avg_sq
                # Out-of-place division: the previous in-place div_() corrupted the
                # running second-moment state on every step.
                exp_avg_sq_hat = exp_avg_sq_hat / bias_correction2

                # AdamS normalization: divide by the global second-moment RMS.
                # exp_avg_sq_hat_mean was previously computed but never used,
                # silently degrading the optimizer to plain Adam(W).
                de_nom = exp_avg_sq_hat.sqrt().div_(exp_avg_sq_hat_mean).add_(group['eps'])

                step_size = group['lr'] if self.adamd_debias_term else group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, de_nom, alpha=-step_size)

        return loss

0 commit comments

Comments
 (0)