Commit ce56167

Merge pull request #95 from kozistr/refactor/optimizers
[Feature] Implement & Optimize a few optimizer options
2 parents f6baa63 + 27d6b99, commit ce56167

24 files changed: +234 additions, -168 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.2.0"
+version = "2.2.1"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/base/optimizer.py

Lines changed: 5 additions & 0 deletions

@@ -92,6 +92,11 @@ def validate_update_frequency(update_frequency: int):
         if update_frequency < 1:
             raise ValueError(f'[-] update_frequency {update_frequency} must be positive')

+    @staticmethod
+    def validate_norm(norm: float):
+        if norm < 0.0:
+            raise ValueError(f'[-] norm {norm} must be positive')
+
     @abstractmethod
     def validate_parameters(self):
         raise NotImplementedError
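
For context, a minimal standalone sketch of how the new validate_norm helper behaves; the class below is illustrative only and is not the library's actual BaseOptimizer:

```python
# Illustrative sketch: mirrors the validate_norm logic added above,
# outside the real BaseOptimizer class.
class _ValidatorSketch:
    @staticmethod
    def validate_norm(norm: float):
        # a norm threshold must not be negative; 0.0 conventionally disables clipping
        if norm < 0.0:
            raise ValueError(f'[-] norm {norm} must be positive')


_ValidatorSketch.validate_norm(5.0)   # passes
_ValidatorSketch.validate_norm(0.0)   # passes (clipping disabled)
# _ValidatorSketch.validate_norm(-1.0)  # would raise ValueError
```

Later in this commit the validator is wired into Adan.validate_parameters() to check the new max_grad_norm option.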

pytorch_optimizer/optimizer/adabelief.py

Lines changed: 7 additions & 8 deletions

@@ -94,6 +94,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         for group in self.param_groups:
+            beta1, beta2 = group['betas']
+            if self.rectify:
+                n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
+
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -128,18 +132,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                 state['step'] += 1
-                beta1, beta2 = group['betas']

                 bias_correction1 = 1.0 - beta1 ** state['step']
                 bias_correction2 = 1.0 - beta2 ** state['step']

                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 grad_residual = grad - exp_avg
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
-
-                exp_avg_var = exp_avg_var.add_(group['eps'])
+                exp_avg_var.add_(group['eps'])
                 if group['amsgrad']:
-                    exp_avg_var = torch.max(state['max_exp_avg_var'], exp_avg_var)
+                    torch.max(state['max_exp_avg_var'], exp_avg_var, out=exp_avg_var)

                 de_nom = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

@@ -155,12 +157,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                    else:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
-                        n_sma_max = 2 / (1 - beta2) - 1
                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = n_sma

                        if n_sma >= self.n_sma_threshold:
-                            rt = math.sqrt(
+                            step_size = math.sqrt(
                                (1 - beta2_t)
                                * (n_sma - 4)
                                / (n_sma_max - 4)
@@ -169,8 +170,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                                * n_sma_max
                                / (n_sma_max - 2)
                            )
-
-                            step_size = rt
                            if not group['adamd_debias_term']:
                                step_size /= bias_correction1
                        elif self.degenerated_to_sgd:
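
Two things change in AdaBelief: the betas unpacking and the n_sma_max constant move out of the per-parameter loop, and the epsilon/AMSGrad updates become in-place, with torch.max(..., out=...) writing the element-wise maximum back into the existing buffer instead of allocating a new tensor. A small sketch of that in-place pattern, with made-up tensor values:

```python
import torch

exp_avg_var = torch.tensor([1.0, 4.0, 2.0])       # second-moment-style buffer
max_exp_avg_var = torch.tensor([3.0, 2.0, 2.0])   # running maximum kept in optimizer state

# before: exp_avg_var = torch.max(max_exp_avg_var, exp_avg_var)   # allocates a new tensor
# after:  write the element-wise maximum back into the existing buffer
torch.max(max_exp_avg_var, exp_avg_var, out=exp_avg_var)

print(exp_avg_var)  # tensor([3., 4., 2.])
```

The same out= pattern is applied below in AdaBound and AdaPNM.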

pytorch_optimizer/optimizer/adabound.py

Lines changed: 3 additions & 5 deletions

@@ -92,6 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         for group, base_lr in zip(self.param_groups, self.base_lrs):
+            beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -112,21 +113,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

-                if group['weight_decay'] != 0:
+                if group['weight_decay'] > 0.0:
                     if self.weight_decouple:
                         p.mul_(
                             1.0 - (group['weight_decay'] if self.fixed_decay else group['lr'] * group['weight_decay'])
                         )
                     else:
                         grad.add_(p, alpha=group['weight_decay'])

-                beta1, beta2 = group['betas']
-
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
-
                 if group['amsbound']:
-                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+                    torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)

                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])

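
The weight-decay guard changes from != 0 to > 0.0, and the branch keeps the two styles already in the code: decoupled decay shrinks the weights directly, while coupled (L2) decay folds the decay term into the gradient. A rough standalone sketch of that branch, with variable names mirroring the diff and all values invented:

```python
import torch

p = torch.ones(3)                 # a parameter tensor
grad = torch.full((3,), 0.1)      # its gradient
lr, weight_decay = 1e-3, 1e-2
weight_decouple, fixed_decay = True, False

if weight_decay > 0.0:
    if weight_decouple:
        # decoupled (AdamW-style): multiply the weights by (1 - lr * wd)
        p.mul_(1.0 - (weight_decay if fixed_decay else lr * weight_decay))
    else:
        # coupled L2 penalty: add wd * p to the gradient
        grad.add_(p, alpha=weight_decay)
```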

pytorch_optimizer/optimizer/adai.py

Lines changed: 5 additions & 4 deletions

@@ -85,6 +85,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         exp_avg_sq_hat_sum: float = 0.0

         for group in self.param_groups:
+            _, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -106,14 +107,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1

                 exp_avg_sq = state['exp_avg_sq']
-                _, beta2 = group['betas']

                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)

                 bias_correction2 = 1.0 - beta2 ** state['step']

-                if group['weight_decay'] != 0:
+                if group['weight_decay'] > 0.0:
                     if self.weight_decouple:
                         p.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
@@ -129,6 +129,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

         for group in self.param_groups:
+            beta0, beta2 = group['betas']
+            beta0_dp = math.pow(beta0, 1.0 - group['dampening'])
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -138,7 +140,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1_prod = state['beta1_prod']
-                beta0, beta2 = group['betas']

                 bias_correction2 = 1.0 - beta2 ** state['step']

@@ -152,7 +153,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 bias_correction1 = 1.0 - beta1_prod

                 exp_avg.mul_(beta1).addcmul_(beta3, grad)
-                exp_avg_hat = exp_avg / bias_correction1 * math.pow(beta0, 1.0 - group['dampening'])
+                exp_avg_hat = exp_avg / bias_correction1 * beta0_dp

                 p.add_(exp_avg_hat, alpha=-group['lr'])

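
The Adai change hoists beta0_dp = beta0 ** (1.0 - dampening) out of the per-parameter loop: the value depends only on group-level hyperparameters, so computing it once per group avoids a redundant math.pow per parameter. A toy sketch of the hoisting pattern, with a deliberately simplified group layout:

```python
import math

# simplified param-group layout; real groups hold parameter tensors, not strings
groups = [{'betas': (0.1, 0.99), 'dampening': 1.0, 'params': ['w1', 'w2', 'b1']}]

for group in groups:
    beta0, beta2 = group['betas']
    beta0_dp = math.pow(beta0, 1.0 - group['dampening'])  # computed once per group
    for p in group['params']:
        # the per-parameter update would reuse beta0_dp here, e.g.:
        # exp_avg_hat = exp_avg / bias_correction1 * beta0_dp
        pass
```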

pytorch_optimizer/optimizer/adamp.py

Lines changed: 8 additions & 7 deletions

@@ -89,6 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         for group in self.param_groups:
+            beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -103,10 +104,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_sq'] = torch.zeros_like(p)

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-
                 state['step'] += 1
-                beta1, beta2 = group['betas']
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                 bias_correction1 = 1.0 - beta1 ** state['step']
                 bias_correction2 = 1.0 - beta2 ** state['step']
@@ -117,12 +116,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

-                de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                inv_de_nom = 1.0 / (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

+                perturb = exp_avg.clone()
                 if group['nesterov']:
-                    perturb = (beta1 * exp_avg + (1.0 - beta1) * grad) / de_nom
+                    # perturb = beta1 * exp_avg + (1.0 - beta1) * grad / de_nom
+                    perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)
                 else:
-                    perturb = exp_avg / de_nom
+                    perturb.mul_(inv_de_nom)

                 wd_ratio: float = 1
                 if len(p.shape) > 1:
@@ -135,7 +136,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         group['eps'],
                     )

-                if group['weight_decay'] > 0:
+                if group['weight_decay'] > 0.0:
                     p.mul_(1.0 - group['lr'] * group['weight_decay'] * wd_ratio)

                 step_size = group['lr']
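
The AdamP change precomputes the reciprocal of the denominator once and then builds the perturbation in place with mul_/addcmul_ on a copy of exp_avg instead of allocating intermediate tensors; per the diff's own inline comment, the Nesterov branch now computes beta1 * exp_avg + (1 - beta1) * grad / de_nom. A small sketch of that in-place formulation, with made-up values:

```python
import torch

beta1 = 0.9
exp_avg = torch.tensor([0.5, -0.2])
grad = torch.tensor([0.1, 0.3])
de_nom = torch.tensor([2.0, 4.0])   # bias-corrected sqrt of the second moment + eps

inv_de_nom = 1.0 / de_nom

# Nesterov branch, built in place on a copy of exp_avg:
perturb = exp_avg.clone()
perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)

# equivalent out-of-place expression (matches the inline comment in the diff)
expected = beta1 * exp_avg + (1.0 - beta1) * grad / de_nom
assert torch.allclose(perturb, expected)
```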

pytorch_optimizer/optimizer/adan.py

Lines changed: 43 additions & 13 deletions

@@ -1,4 +1,5 @@
 import math
+from typing import Union

 import torch
 from torch.optim.optimizer import Optimizer
@@ -17,6 +18,7 @@ class Adan(Optimizer, BaseOptimizer):
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. decoupled weight decay.
+    :param max_grad_norm: float. max gradient norm to clip.
     :param use_gc: bool. use gradient centralization.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
@@ -28,13 +30,15 @@ def __init__(
         betas: BETAS = (0.98, 0.92, 0.99),
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
+        max_grad_norm: float = 0.0,
         use_gc: bool = False,
         eps: float = 1e-8,
     ):
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.weight_decouple = weight_decouple
+        self.max_grad_norm = max_grad_norm
         self.use_gc = use_gc
         self.eps = eps

@@ -46,6 +50,7 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             weight_decouple=weight_decouple,
+            max_grad_norm=max_grad_norm,
         )
         super().__init__(params, defaults)

@@ -54,6 +59,7 @@ def validate_parameters(self):
         self.validate_betas(self.betas)
         self.validate_weight_decay(self.weight_decay)
         self.validate_epsilon(self.eps)
+        self.validate_norm(self.max_grad_norm)

     @property
     def __name__(self) -> str:
@@ -62,23 +68,54 @@ def __name__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]

-                state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p)
                 state['exp_avg_diff'] = torch.zeros_like(p)
                 state['exp_avg_nest'] = torch.zeros_like(p)
                 state['previous_grad'] = torch.zeros_like(p)

+    @torch.no_grad()
+    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
+        if self.defaults['max_grad_norm'] == 0.0:
+            return 1.0
+
+        device = self.param_groups[0]['params'][0].device
+
+        global_grad_norm = torch.zeros(1, device=device)
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))
+
+        global_grad_norm = torch.sqrt(global_grad_norm)
+
+        return torch.clamp(max_grad_norm / (global_grad_norm + self.eps), max=1.0)
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()

+        clip_global_grad_norm = self.get_global_gradient_norm()
+
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            beta1, beta2, beta3 = group['betas']
+            bias_correction1 = 1.0 - beta1 ** group['step']
+            bias_correction2 = 1.0 - beta2 ** group['step']
+            bias_correction3_sq = math.sqrt(1.0 - beta3 ** group['step'])
+
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -89,35 +126,28 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_diff'] = torch.zeros_like(p)
                     state['exp_avg_nest'] = torch.zeros_like(p)
-                    state['previous_grad'] = torch.zeros_like(p)
+                    state['previous_grad'] = grad.clone()

                 exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
-                prev_grad = state['previous_grad']
-
-                state['step'] += 1
-                beta1, beta2, beta3 = group['betas']

-                bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2 = 1.0 - beta2 ** state['step']
-                bias_correction3 = 1.0 - beta3 ** state['step']
+                grad.mul_(clip_global_grad_norm)

                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)

-                grad_diff = grad - prev_grad
-                state['previous_grad'] = grad.clone()
+                grad_diff = grad - state['previous_grad']
+                state['previous_grad'].copy_(grad)

                 update = grad + beta2 * grad_diff

                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
                 exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)

-                de_nom = (exp_avg_nest.sqrt_() / math.sqrt(bias_correction3)).add_(self.eps)
+                de_nom = (exp_avg_nest.sqrt_() / bias_correction3_sq).add_(self.eps)
                 perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)

                 if group['weight_decouple']:
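
The most substantial change here is the new max_grad_norm option: before the update, get_global_gradient_norm computes a single scale factor clamp(max_grad_norm / (||g||_2 + eps), max=1.0) over all gradients, and each gradient is multiplied by it, so the global norm is clipped while gradients that are already small pass through unchanged. The commit also moves the step counter and the bias-correction terms from per-parameter state to the parameter group, so they are computed once per group per step. A self-contained sketch of the same clipping scheme outside the optimizer, with arbitrary parameter shapes and values:

```python
import torch

max_grad_norm, eps = 1.0, 1e-8
params = [torch.randn(4, requires_grad=True), torch.randn(3, requires_grad=True)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

# accumulate the squared L2 norms of every gradient, then take one square root
global_grad_norm = torch.zeros(1)
for p in params:
    if p.grad is not None:
        global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))
global_grad_norm = torch.sqrt(global_grad_norm)

# the scale factor is at most 1.0, so small gradients are left untouched
clip_coef = torch.clamp(max_grad_norm / (global_grad_norm + eps), max=1.0)
for p in params:
    if p.grad is not None:
        p.grad.mul_(clip_coef)
```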

pytorch_optimizer/optimizer/adapnm.py

Lines changed: 6 additions & 7 deletions

@@ -84,6 +84,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         for group in self.param_groups:
+            beta1, beta2, beta3 = group['betas']
+            noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # fmt: skip
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -107,7 +109,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         state['max_exp_avg_sq'] = torch.zeros_like(p)

                 state['step'] += 1
-                beta1, beta2, beta3 = group['betas']

                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']
@@ -120,18 +121,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)  # fmt: skip
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
-
                 if group['amsgrad']:
-                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+                    torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)

-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                 step_size = group['lr']
                 if not group['adamd_debias_term']:
                     step_size /= bias_correction1

-                noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # fmt: skip
-                pn_momentum = exp_avg.mul(1 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
-                p.addcdiv_(pn_momentum, denom, value=-step_size)
+                pn_momentum = exp_avg.mul(1.0 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
+                p.addcdiv_(pn_momentum, de_nom, value=-step_size)

         return loss
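
In AdaPNM, noise_norm = sqrt((1 + beta3)^2 + beta3^2) is hoisted out of the parameter loop since it depends only on beta3, while the positive-negative momentum term stays ((1 + beta3) * exp_avg - beta3 * neg_exp_avg) / noise_norm. A small numeric sketch of that combination, with invented momentum buffers:

```python
import math

import torch

beta3 = 0.9
noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # depends only on beta3

exp_avg = torch.tensor([0.4, -0.1])       # momentum buffer for the current step's parity
neg_exp_avg = torch.tensor([0.3, -0.2])   # momentum buffer for the previous step's parity

# positive-negative momentum, combined as in the diff:
pn_momentum = exp_avg.mul(1.0 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)

expected = ((1.0 + beta3) * exp_avg - beta3 * neg_exp_avg) / noise_norm
assert torch.allclose(pn_momentum, expected)
```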
