
Commit 81acb1e

Merge pull request #332 from kozistr/fix/wrapper-optimizer
[Fix] Wrapper optimizers
2 parents 7a9377a + dae7169 commit 81acb1e

File tree

12 files changed, +99 −42 lines


docs/changelogs/v3.3.5.md

Lines changed: 11 additions & 0 deletions

@@ -1,8 +1,19 @@
 ### Change Log
 
+### Feature
+
+* Implement `FOCUS` optimizer. (#330, #331)
+    * [First Order Concentrated Updating Scheme](https://arxiv.org/abs/2501.12243)
+
+### Update
+
+* Support `OrthoGrad` variant to `Ranger25`. (#332)
+
 ### Fix
 
 * Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327)
+* Add the missing `state_dict`, and `load_state_dict` methods to `TRAC` and `OrthoGrad` optimizers. (#332)
+* Skip when the gradient is sparse in `OrthoGrad` optimizer. (#332)
 
 ### Contributions
 

pyproject.toml

Lines changed: 4 additions & 2 deletions

@@ -97,8 +97,10 @@ select = [
     "TID", "ARG", "ERA", "RUF", "YTT", "PL", "Q"
 ]
 ignore = [
-    "A005", "B905", "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413", "PIE790", "PLR0912", "PLR0913",
-    "PLR0915", "PLR2004", "RUF013", "Q003", "ARG002",
+    "A005", "B905",
+    "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413",
+    "PLR0912", "PLR0913", "PLR0915", "PLR2004",
+    "Q003", "ARG002",
 ]
 fixable = ["ALL"]
 unfixable = ["F401"]

pytorch_optimizer/base/optimizer.py

Lines changed: 22 additions & 1 deletion

@@ -6,7 +6,16 @@
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS, STATE
+from pytorch_optimizer.base.types import (
+    BETAS,
+    CLOSURE,
+    DEFAULTS,
+    HUTCHINSON_G,
+    LOSS,
+    OPTIMIZER_INSTANCE_OR_CLASS,
+    PARAMETERS,
+    STATE,
+)
 
 
 class BaseOptimizer(ABC, Optimizer):
@@ -15,6 +24,18 @@ class BaseOptimizer(ABC, Optimizer):
     def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None:
         super().__init__(params, defaults)
 
+    @staticmethod
+    def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
+        r"""Build torch.optim.Optimizer class."""
+        if isinstance(optimizer, Optimizer):
+            return optimizer
+
+        if 'params' in kwargs:
+            params = kwargs.pop('params')
+            return optimizer(params, **kwargs)
+
+        raise ValueError('need to pass `params` when you pass the `torch.optim.Optimizer` instance.')
+
     @staticmethod
     @torch.no_grad()
     def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
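
A minimal usage sketch of the two construction paths that `load_optimizer` accepts, shown here through the `OrthoGrad` wrapper; the `nn.Linear` model and the use of `torch.optim.AdamW` are illustrative, not part of the PR.

# Sketch only: the two ways a wrapper optimizer can now be built via
# BaseOptimizer.load_optimizer. Model and inner optimizer are illustrative.
from torch import nn
from torch.optim import AdamW

from pytorch_optimizer import OrthoGrad

model = nn.Linear(10, 2)

# 1) pass an already-built optimizer instance
opt_a = OrthoGrad(AdamW(model.parameters(), lr=1e-3))

# 2) pass the optimizer class plus a `params` keyword (and its kwargs)
opt_b = OrthoGrad(AdamW, params=model.parameters(), lr=1e-3)

# 3) passing a class without `params` raises the ValueError from load_optimizer
try:
    OrthoGrad(AdamW, lr=1e-3)
except ValueError as e:
    print(e)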

pytorch_optimizer/loss/jaccard.py

Lines changed: 2 additions & 2 deletions

@@ -49,7 +49,7 @@ class JaccardLoss(_Loss):
     def __init__(
         self,
         mode: CLASS_MODE,
-        classes: List[int] = None,
+        classes: Optional[List[int]] = None,
         log_loss: bool = False,
         from_logits: bool = True,
         label_smooth: float = 0.0,
@@ -59,7 +59,7 @@ def __init__(
 
         if classes is not None:
             if mode == 'binary':
-                raise ValueError('[-] Masking classes is not supported with mode=binary')
+                raise ValueError('masking classes is not supported with mode=binary')
 
             classes = torch.LongTensor(classes)
 
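
A small illustrative sketch of the behavior around the corrected signature and error message; the 'multiclass' mode string is an assumption about `CLASS_MODE`, not something shown in this diff.

# Sketch: `classes` masking is rejected for mode='binary' (per the diff above),
# while non-binary modes accept an optional class list.
from pytorch_optimizer.loss.jaccard import JaccardLoss

loss_fn = JaccardLoss(mode='multiclass', classes=[1, 2])  # mask the loss to classes 1 and 2

try:
    JaccardLoss(mode='binary', classes=[1])
except ValueError as e:
    print(e)  # masking classes is not supported with mode=binary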

pytorch_optimizer/optimizer/experimental/ranger25.py

Lines changed: 23 additions & 1 deletion

@@ -11,7 +11,7 @@
 class Ranger25(BaseOptimizer):
     r"""Mixin' every fancy optimizer hacks.
 
-    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2
+    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2 + OrthoGrad
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
@@ -23,6 +23,7 @@ class Ranger25(BaseOptimizer):
     :param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
     :param cautious: bool. whether to use the Cautious variant.
     :param stable_adamw: bool. whether to use stable AdamW variant.
+    :param orthograd: bool. whether to use orthograd variant.
     :param eps: Optional[float]. term added to the denominator to improve numerical stability. when eps is None and
         stable_adamw is False, adam-atan2 feature will be used.
     """
@@ -39,6 +40,7 @@ def __init__(
         t_alpha_beta3: Optional[float] = None,
         cautious: bool = True,
         stable_adamw: bool = True,
+        orthograd: bool = True,
         eps: Optional[float] = 1e-8,
         **kwargs,
     ):
@@ -51,6 +53,7 @@ def __init__(
 
         self.cautious = cautious
         self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False
+        self.orthograd = orthograd
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -97,13 +100,32 @@ def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta
             beta3,
         )
 
+    @torch.no_grad()
+    def orthogonalize_gradients(self, params, eps: float = 1e-16) -> None:
+        for p in params:
+            if p.grad is None or p.grad.is_sparse:
+                continue
+
+            w = p.view(-1)
+            g = p.grad.view(-1)
+
+            proj = torch.dot(w, g).div_(torch.dot(w, w).add_(eps))
+            g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
+            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(eps)))
+
+            p.grad.copy_(g_ortho_scaled.view_as(p.grad))
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
+        if self.orthograd:
+            for group in self.param_groups:
+                self.orthogonalize_gradients(group['params'])
+
         for group in self.param_groups:
             if 'step' in group:
                 group['step'] += 1
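
A standalone toy check of the projection that `orthogonalize_gradients` applies: drop the component of the gradient parallel to the flattened weights, then rescale the remainder back to the original gradient norm. This is a plain-tensor sketch mirroring the method above, not optimizer code; the `eps` value is the one from the diff.

# Toy check of the OrthoGrad-style projection used by Ranger25 above.
import torch

eps = 1e-16
w = torch.randn(1024)  # flattened parameter
g = torch.randn(1024)  # flattened gradient

proj = torch.dot(w, g) / (torch.dot(w, w) + eps)           # component of g along w
g_ortho = g - proj * w                                     # remove the parallel part
g_ortho = g_ortho * (g.norm(2) / (g_ortho.norm(2) + eps))  # restore the original norm

print(torch.dot(w, g_ortho).abs().item())        # ~0: orthogonal to the weights
print(g.norm(2).item(), g_ortho.norm(2).item())  # norms match up to eps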

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 1 addition & 7 deletions

@@ -29,13 +29,7 @@ def __init__(
         self.validate_range(alpha, 'alpha', 0.0, 1.0)
         self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])
 
-        if isinstance(optimizer, Optimizer):
-            self.optimizer = optimizer
-        elif 'params' in kwargs:
-            params = kwargs.pop('params')
-            self.optimizer = optimizer(params, **kwargs)
-        else:
-            raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
+        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)
 
         self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}

pytorch_optimizer/optimizer/orthograd.py

Lines changed: 10 additions & 10 deletions

@@ -4,7 +4,7 @@
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE
 
 
 class OrthoGrad(BaseOptimizer):
@@ -20,13 +20,7 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
         self.eps: float = 1e-30
 
-        if isinstance(optimizer, Optimizer):
-            self.optimizer = optimizer
-        elif 'params' in kwargs:
-            params = kwargs.pop('params')
-            self.optimizer = optimizer(params, **kwargs)
-        else:
-            raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
+        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)
 
         self.defaults: DEFAULTS = self.optimizer.defaults
 
@@ -38,9 +32,15 @@ def param_groups(self):
         return self.optimizer.param_groups
 
     @property
-    def state(self):
+    def state(self) -> STATE:
         return self.optimizer.state
 
+    def state_dict(self) -> STATE:
+        return self.optimizer.state_dict()
+
+    def load_state_dict(self, state_dict: STATE) -> None:
+        self.optimizer.load_state_dict(state_dict)
+
     @torch.no_grad()
     def zero_grad(self) -> None:
         self.optimizer.zero_grad(set_to_none=True)
@@ -52,7 +52,7 @@ def reset(self):
     @torch.no_grad()
     def orthogonalize_gradients(self, params) -> None:
         for p in params:
-            if p.grad is None:
+            if p.grad is None or p.grad.is_sparse:
                 continue
 
             w = p.view(-1)
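
A hedged sketch of why the sparse-gradient skip matters: `Tensor.view(-1)` is not available for sparse gradients, so `orthogonalize_gradients` now leaves them untouched and the wrapped optimizer handles them as usual. The sparse `nn.Embedding` setup below is illustrative only, and assumes `OrthoGrad` is importable from the package top level.

# Sketch: sparse gradients (e.g. from nn.Embedding(sparse=True)) are now
# skipped by OrthoGrad.orthogonalize_gradients instead of failing on .view(-1).
import torch
from torch import nn

from pytorch_optimizer import OrthoGrad

emb = nn.Embedding(100, 16, sparse=True)
optimizer = OrthoGrad(torch.optim.SGD(emb.parameters(), lr=1e-2))  # plain SGD handles sparse grads

tokens = torch.randint(0, 100, (8,))
loss = emb(tokens).sum()
loss.backward()          # emb.weight.grad is a sparse tensor

optimizer.step()         # sparse grad is skipped by the projection, then applied by SGD
optimizer.zero_grad()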

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 0 additions & 1 deletion

@@ -31,7 +31,6 @@ def __init__(self, *args):
 
     def add_statistics(self, grad: torch.Tensor, unused_beta2: float) -> None:
         r"""Add the statistics."""
-        pass
 
     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
         r"""Get preconditioned gradient."""

pytorch_optimizer/optimizer/trac.py

Lines changed: 12 additions & 12 deletions

@@ -5,7 +5,7 @@
 from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE
 
 
 def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
@@ -112,23 +112,17 @@ def __init__(
         self.validate_non_negative(s_prev, 's_prev')
         self.validate_non_negative(eps, 'eps')
 
-        if isinstance(optimizer, Optimizer):
-            self.optimizer = optimizer
-        elif 'params' in kwargs:
-            params = kwargs.pop('params')
-            self.optimizer = optimizer(params, **kwargs)
-        else:
-            raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
-
         self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
 
-        self.erf = ERF1994(num_coefs=num_coefs)
+        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)
+
         self.betas = betas
         self.s_prev = s_prev
         self.eps = eps
 
-        self.f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))
+        self.erf: nn.Module = ERF1994(num_coefs=num_coefs)
+        self.f_term: torch.Tensor = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))
 
         self.defaults: DEFAULTS = self.optimizer.defaults
 
@@ -140,9 +134,15 @@ def param_groups(self):
         return self.optimizer.param_groups
 
     @property
-    def state(self):
+    def state(self) -> STATE:
         return self.optimizer.state
 
+    def state_dict(self) -> STATE:
+        return self.optimizer.state_dict()
+
+    def load_state_dict(self, state_dict: STATE) -> None:
+        self.optimizer.load_state_dict(state_dict)
+
     @torch.no_grad()
     def reset(self):
         device = self.param_groups[0]['params'][0].device
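
With `state_dict` and `load_state_dict` delegated to the wrapped optimizer, the usual checkpoint round-trip now works through the wrapper. A minimal sketch for `TRAC` follows; the model, loss, and use of `torch.optim.AdamW` are illustrative, and the same pattern applies to `OrthoGrad`.

# Sketch: state_dict/load_state_dict on the TRAC wrapper delegate to the
# wrapped optimizer, so a checkpoint round-trip works as usual.
import torch
from torch import nn

from pytorch_optimizer import TRAC

model = nn.Linear(4, 1)
optimizer = TRAC(torch.optim.AdamW(model.parameters(), lr=1e-3))

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

checkpoint = optimizer.state_dict()    # delegated to the wrapped AdamW
optimizer.load_state_dict(checkpoint)  # restores the wrapped optimizer's state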

pytorch_optimizer/optimizer/utils.py

Lines changed: 2 additions & 1 deletion

@@ -48,7 +48,8 @@ def is_deepspeed_zero3_enabled() -> bool:
         return is_deepspeed_zero3_enabled()  # pragma: no cover
 
     warnings.warn(
-        'you need to install `transformers` to use `is_deepspeed_zero3_enabled` function. it\'ll return False.',
+        'you need to install `transformers` to use `is_deepspeed_zero3_enabled` function. '
+        'it will return False.',
         category=ImportWarning,
         stacklevel=2,
     )
