Skip to content

Commit 1524b89

Browse files
authored
Merge pull request #294 from kozistr/update/codes
[Feature] Cautious optimizer, improve the stability of ADOPT optimizer, a new projector type `random` for `GaLore` optimizer
2 parents 131314b + db82a58 commit 1524b89

File tree

10 files changed

+85
-16
lines changed

10 files changed

+85
-16
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
183183
| SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | <https://arxiv.org/abs/2409.11321> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) |
184184
| ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | <https://arxiv.org/abs/2411.02853> | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) |
185185
| FTRL | *Follow The Regularized Leader* | | <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf> | |
186+
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |
186187

187188
## Supported LR Scheduler
188189

docs/changelogs/v3.3.0.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
* [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853)
99
* Implement `FTRL` optimizer. (#291)
1010
* [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
11+
* Implement `Cautious optimizer` feature. (#294)
12+
* [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1)
13+
* you can use it by setting `cautious=True` for `Lion`, `AdaFactor` and `AdEMAMix` optimizers.
14+
* Improve the stability of `ADOPT` optimizer. (#294)
15+
* [Note](https://github.com/iShohei220/adopt?tab=readme-ov-file#update-on-nov-22-2024)
16+
* Support a new projection type `random` for `GaLoreProjector`. (#294)
1117

1218
### Refactor
1319

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
183183
| SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | <https://arxiv.org/abs/2409.11321> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) |
184184
| ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | <https://arxiv.org/abs/2411.02853> | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) |
185185
| FTRL | *Follow The Regularized Leader* | | <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf> | |
186+
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |
186187

187188
## Supported LR Scheduler
188189

pytorch_optimizer/base/optimizer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,17 @@ def approximate_sq_grad(
255255
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
256256
torch.mul(r_factor, c_factor, out=output)
257257

258+
@staticmethod
259+
def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
260+
r"""Apply the Cautious Optimizer feature.
261+
262+
:param update: torch.Tensor. update. it'll be masked in in-place manner.
263+
:param grad: torch.Tensor. gradient.
264+
"""
265+
mask = (update * grad > 0).to(grad.dtype)
266+
mask.mul_(mask.numel() / (mask.sum() + 1))
267+
update.mul_(mask)
268+
258269
@staticmethod
259270
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
260271
if range_type == '[)' and not low <= x < high:

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class AdaFactor(BaseOptimizer):
3030
:param momentum_dtype: torch.dtype. type of momentum variable. In VIT paper observed that storing momentum in
3131
half-precision (bfloat16 type) does not affect training dynamics and has no effect on the outcome while
3232
reducing optimize overhead from 2-fold to 1.5-fold.
33+
:param cautious: bool. whether to use the Cautious variant.
3334
"""
3435

3536
def __init__(
@@ -49,6 +50,7 @@ def __init__(
4950
eps1: float = 1e-30,
5051
eps2: float = 1e-3,
5152
momentum_dtype: torch.dtype = torch.bfloat16,
53+
cautious: bool = False,
5254
**kwargs,
5355
):
5456
self.validate_learning_rate(lr)
@@ -62,6 +64,7 @@ def __init__(
6264
self.eps1 = eps1
6365
self.eps2 = eps2
6466
self.momentum_dtype = momentum_dtype
67+
self.cautious = cautious
6568

6669
defaults: DEFAULTS = {
6770
'lr': lr,
@@ -214,7 +217,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
214217
exp_avg = state['exp_avg']
215218
exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)
216219

217-
update = exp_avg
220+
update = exp_avg.clone()
221+
if self.cautious:
222+
self.apply_cautious(update, grad)
218223

219224
self.apply_weight_decay(
220225
p=p,

pytorch_optimizer/optimizer/ademamix.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class AdEMAMix(BaseOptimizer):
1919
:param fixed_decay: bool. fix weight decay.
2020
:param alpha: float. usually between 4 and 10 would work well.
2121
:param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
22+
:param cautious: bool. whether to use cautious feature.
2223
:param eps: float. term added to the denominator to improve numerical stability.
2324
"""
2425

@@ -32,6 +33,7 @@ def __init__(
3233
fixed_decay: bool = False,
3334
alpha: float = 5.0,
3435
t_alpha_beta3: Optional[float] = None,
36+
cautious: bool = False,
3537
eps: float = 1e-8,
3638
**kwargs,
3739
):
@@ -42,6 +44,8 @@ def __init__(
4244
self.validate_non_negative(weight_decay, 'weight_decay')
4345
self.validate_non_negative(eps, 'eps')
4446

47+
self.cautious = cautious
48+
4549
defaults: DEFAULTS = {
4650
'lr': lr,
4751
'betas': betas,
@@ -71,9 +75,7 @@ def reset(self):
7175

7276
@staticmethod
7377
def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
74-
if t_alpha_beta3 is None:
75-
return alpha
76-
return min(step * alpha / t_alpha_beta3, alpha)
78+
return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)
7779

7880
@staticmethod
7981
def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
@@ -107,6 +109,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
107109
bias_correction1: float = self.debias(beta1, group['step'])
108110
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
109111

112+
step_size: float = group['lr'] / bias_correction1
113+
110114
alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
111115
beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)
112116

@@ -140,10 +144,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
140144
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
141145
exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t)
142146

143-
de_nom = (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
147+
de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
144148

145-
step_size = group['lr'] / bias_correction1
149+
update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom)
150+
if self.cautious:
151+
self.apply_cautious(update, grad)
146152

147-
p.addcdiv_(exp_avg + alpha_t * exp_avg_slow, de_nom, value=-step_size)
153+
p.add_(update, alpha=-step_size)
148154

149155
return loss

pytorch_optimizer/optimizer/adopt.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import math
2+
from typing import Callable, Optional
3+
14
import torch
25

36
from pytorch_optimizer.base.exception import NoSparseGradientError
@@ -22,6 +25,7 @@ def __init__(
2225
params: PARAMETERS,
2326
lr: float = 1e-3,
2427
betas: BETAS = (0.9, 0.9999),
28+
clip_lambda: Optional[Callable[[float], float]] = lambda step: math.pow(step, 0.25),
2529
weight_decay: float = 0.0,
2630
weight_decouple: bool = False,
2731
fixed_decay: bool = False,
@@ -33,6 +37,8 @@ def __init__(
3337
self.validate_non_negative(weight_decay, 'weight_decay')
3438
self.validate_non_negative(eps, 'eps')
3539

40+
self.clip_lambda = clip_lambda
41+
3642
defaults: DEFAULTS = {
3743
'lr': lr,
3844
'betas': betas,
@@ -104,10 +110,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
104110
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)
105111

106112
de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])
107-
if group['step'] == 2:
108-
exp_avg.addcdiv_(grad, de_nom)
109-
else:
110-
exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=1.0 - beta1)
113+
114+
normed_grad = grad.div(de_nom)
115+
if self.clip_lambda is not None:
116+
clip = self.clip_lambda(group['step'])
117+
normed_grad.clamp_(-clip, clip)
118+
119+
exp_avg.lerp_(normed_grad, weight=1.0 - beta1)
111120

112121
p.add_(exp_avg, alpha=-group['lr'])
113122

pytorch_optimizer/optimizer/galore.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytorch_optimizer.base.optimizer import BaseOptimizer
88
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
99

10-
PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full']
10+
PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full', 'random']
1111

1212

1313
class GaLoreProjector:
@@ -16,8 +16,8 @@ class GaLoreProjector:
1616
:param rank: int. low rank to project.
1717
:param update_proj_gap: int. num steps to update the projection.
1818
:param scale: float. scale factor.
19-
:param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
20-
supported.
19+
:param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' and
20+
'random' are supported.
2121
"""
2222

2323
def __init__(
@@ -101,6 +101,14 @@ def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor
101101
self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
102102
return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()
103103

104+
def get_low_rank_grad_random(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
105+
is_right: bool = grad.size(0) >= grad.size(1)
106+
if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
107+
self.ortho_matrix = self.get_orthogonal_matrix(
108+
grad, self.rank, projection_type='right' if is_right else 'left'
109+
)
110+
return torch.matmul(grad, self.ortho_matrix.t()) if is_right else torch.matmul(self.ortho_matrix.t(), grad)
111+
104112
def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
105113
if self.projection_type == 'std':
106114
return self.get_low_rank_grad_std(full_rank_grad, steps)
@@ -112,6 +120,8 @@ def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
112120
return self.get_low_rank_grad_left(full_rank_grad, steps)
113121
if self.projection_type == 'full':
114122
return self.get_low_rank_grad_full(full_rank_grad, steps)
123+
if self.projection_type == 'random':
124+
return self.get_low_rank_grad_random(full_rank_grad, steps)
115125
raise NotImplementedError
116126

117127
def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
@@ -133,6 +143,12 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
133143
return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
134144
if self.projection_type == 'full':
135145
return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale
146+
if self.projection_type == 'random':
147+
return (
148+
torch.matmul(low_rank_grad, self.ortho_matrix.t())
149+
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
150+
else torch.matmul(self.ortho_matrix, low_rank_grad)
151+
) * self.scale
136152

137153
raise NotImplementedError
138154

pytorch_optimizer/optimizer/lion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Lion(BaseOptimizer):
1818
:param use_gc: bool. use gradient centralization.
1919
:param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
2020
:param adanorm: bool. whether to use the AdaNorm variant.
21+
:param cautious: bool. whether to use the Cautious variant.
2122
"""
2223

2324
def __init__(
@@ -31,13 +32,15 @@ def __init__(
3132
use_gc: bool = False,
3233
r: float = 0.95,
3334
adanorm: bool = False,
35+
cautious: bool = False,
3436
**kwargs,
3537
):
3638
self.validate_learning_rate(lr)
3739
self.validate_betas(betas)
3840
self.validate_non_negative(weight_decay, 'weight_decay')
3941

4042
self.use_gc = use_gc
43+
self.cautious = cautious
4144

4245
defaults: DEFAULTS = {
4346
'lr': lr,
@@ -114,6 +117,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
114117
update.mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
115118
exp_avg.mul_(beta2).add_(s_grad, alpha=1.0 - beta2)
116119

120+
if self.cautious:
121+
self.apply_cautious(update, grad)
122+
117123
p.add_(update, alpha=-group['lr'])
118124

119125
return loss

tests/constants.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@
375375
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
376376
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
377377
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
378+
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'cautious': True}, 70),
378379
(AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
379380
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
380381
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
@@ -383,6 +384,7 @@
383384
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
384385
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5),
385386
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10),
387+
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'cautious': True}, 5),
386388
(AliG, {'max_lr': 5e-1, 'momentum': 0.9}, 5),
387389
(AliG, {'max_lr': 5e-1, 'momentum': 0.9, 'adjusted_momentum': True}, 5),
388390
(SM3, {'lr': 5e-1, 'momentum': 0.9, 'beta': 0.9}, 5),
@@ -469,6 +471,11 @@
469471
{'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
470472
5,
471473
),
474+
(
475+
GaLore,
476+
{'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'random'},
477+
5,
478+
),
472479
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
473480
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
474481
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
@@ -478,8 +485,9 @@
478485
(Kate, {'lr': 5e-2}, 10),
479486
(StableAdamW, {'lr': 1e0}, 5),
480487
(AdamG, {'lr': 1e0}, 20),
481-
(AdEMAMix, {'lr': 1e0}, 5),
482-
(AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 5),
488+
(AdEMAMix, {'lr': 1e0}, 3),
489+
(AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 3),
490+
(AdEMAMix, {'lr': 1e0, 'cautious': True}, 2),
483491
(
484492
SOAP,
485493
{'lr': 1e0, 'shampoo_beta': 0.95, 'precondition_frequency': 1, 'merge_dims': False, 'precondition_1d': True},

0 commit comments

Comments
 (0)