
Commit 733f246

Merge pull request #380 from kozistr/update/stuff
[Feature] Support complex parameter & `maximize` option for all optimizers
2 parents: b26a78e + 12d0ef7


107 files changed: +4,037 / −3,036 lines

docs/changelogs/v3.5.2.md

Lines changed: 0 additions & 14 deletions
This file was deleted.

docs/changelogs/v3.6.0.md

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
## Change Log

### Feature

* Implement `Fira` optimizer. (#376)
    * [Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?](https://arxiv.org/abs/2410.01623)
* Implement `RACS` and `Alice` optimizers. (#376)
    * [Towards Efficient Optimizer Design for LLM via Structured Fisher Approximation with a Low-Rank Extension](https://arxiv.org/abs/2502.07752)
* Implement `VSGD` optimizer. (#377, #378)
    * [Variational Stochastic Gradient Descent for Deep Neural Networks](https://openreview.net/forum?id=xu4ATNjcdy)
* Support complex parameters. (#370, #380)
* Support the `maximize` parameter. (#370, #380)

### Update

* Support tensors with more than two dimensions for the `RACS` and `Alice` optimizers. (#380)
* Remove the auxiliary variants from the optimizers' default parameters and rename the corresponding state and parameter. (#380)
    * `use_gc`, `adanorm`, `cautious`, `stable_adamw`, and `adam_debias` are affected.
    * You can still use these variants by passing the parameters via `**kwargs`.
    * Notably, for the `adanorm` variant, you need to pass the `adanorm` parameter (and `adanorm_r` for the `r` option), and the state name changes from `exp_avg_norm` to `exp_avg_adanorm`.
* Refactor the `reset()` method into `init_group()` in the `BaseOptimizer` class. (#380)
* Refactor the `SAM` optimizer family. (#380)

### Fix

* Fix shape-mismatch issues in the GaLore projection for the `reverse_std`, `right`, and `full` projection types. (#376)
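To make the `**kwargs` note in the Update section above concrete, here is a minimal sketch of how the relocated variant flags are expected to be passed after this release. It is illustrative only: `AdamP` stands in for any optimizer that forwards these flags through `**kwargs`, and the keyword names (`adanorm`, `adanorm_r`, `maximize`) are taken from the changelog text rather than a verified signature.

# Illustrative sketch, not part of this commit: assumes AdamP forwards the
# variant flags through **kwargs as described in the v3.6.0 changelog above.
import torch

from pytorch_optimizer import AdamP

model = torch.nn.Linear(4, 1)

# `adanorm` is no longer a default argument; it is passed via **kwargs,
# and its EMA factor is now named `adanorm_r` (state key: `exp_avg_adanorm`).
optimizer = AdamP(model.parameters(), lr=1e-3, adanorm=True, adanorm_r=0.95, maximize=False)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()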

docs/optimizer.md

Lines changed: 8 additions & 0 deletions
@@ -116,6 +116,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.Alice
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.AliG
     :docstring:
     :members:
@@ -316,6 +320,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.RACS
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.RAdam
     :docstring:
     :members:

docs/qa.md

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@
 
 ## Q3) How to run visualizations?
 
-Run `python3 -m examples.visualize_optimizers` on the project root.
+Run `make visualize` or `python3 -m examples.visualize_optimizers` on the project root.
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import os

import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from pytorch_optimizer import Lookahead


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)

        z = self.encoder(x)
        x_hat = self.decoder(z)

        loss = nn.functional.mse_loss(x_hat, x)

        self.log('train_loss', loss)

        return loss

    def configure_optimizers(self):
        return Lookahead(AdamW(self.parameters(), lr=1e-3), k=5, alpha=0.5)


def main():
    train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(train_dataset)

    autoencoder = LitAutoEncoder()
    autoencoder.train()

    if torch.cuda.is_available():
        autoencoder.cuda()

    trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)


if __name__ == '__main__':
    main()
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
import os

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from pytorch_optimizer import SophiaH


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()

        x, y = batch
        x = x.view(x.size(0), -1)

        z = self.encoder(x)
        x_hat = self.decoder(z)

        loss = nn.functional.mse_loss(x_hat, x)

        self.manual_backward(loss, create_graph=True)
        opt.step()

        self.log('train_loss', loss)

        return loss

    def configure_optimizers(self):
        return SophiaH(self.parameters())


def main():
    train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(train_dataset)

    autoencoder = LitAutoEncoder()
    autoencoder.train()

    if torch.cuda.is_available():
        autoencoder.cuda()

    trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)


if __name__ == '__main__':
    main()

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ select = [
 ignore = [
     "B905",
     "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413",
-    "PLR0912", "PLR0913", "PLR0915", "PLR2004",
+    "PLR0912", "PLR0913", "PLR0915", "PLR2004", "PLW2901",
     "Q003", "ARG002",
 ]
 fixable = ["ALL"]
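For context, `PLW2901` is Ruff's `redefined-loop-name` rule, which flags a loop variable that is reassigned inside the loop body. Ignoring it fits the complex-parameter handling elsewhere in this commit, where parameters are rebound inside parameter-group loops. A small sketch of the pattern the rule flags (illustrative, not taken from this diff):

import torch

params = [torch.randn(3, dtype=torch.cfloat), torch.randn(2, 2)]

for p in params:
    if torch.is_complex(p):
        p = torch.view_as_real(p)  # PLW2901: loop variable `p` is reassigned here
    print(p.shape)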

pytorch_optimizer/base/exception.py

Lines changed: 13 additions & 0 deletions
@@ -43,3 +43,16 @@ def __init__(self, num_steps: int, step_type: str = ''):
         self.note: str = step_type if step_type else 'step'
         self.message: str = f'{self.note} must be positive. ({num_steps} > 0)'
         super().__init__(self.message)
+
+
+class NoComplexParameterError(Exception):
+    """Raised when the dtype of the parameter is complex.
+
+    :param optimizer_name: str. optimizer name.
+    :param note: str. special conditions to note (default '').
+    """
+
+    def __init__(self, optimizer_name: str, note: str = ''):
+        self.note: str = ' ' if not note else f' w/ {note} '
+        self.message: str = f'{optimizer_name}{self.note}does not support complex parameter.'
+        super().__init__(self.message)
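A hypothetical usage sketch for the new exception: the guard function and the optimizer name below are illustrative, not taken from this commit, and only show how an optimizer without complex support might raise it.

import torch

from pytorch_optimizer.base.exception import NoComplexParameterError


def check_no_complex(params, optimizer_name: str) -> None:
    # Raise the new exception if any parameter tensor has a complex dtype.
    for p in params:
        if torch.is_complex(p):
            raise NoComplexParameterError(optimizer_name)


check_no_complex([torch.randn(2, 2)], 'SomeOptimizer')                        # passes
# check_no_complex([torch.randn(2, 2, dtype=torch.cfloat)], 'SomeOptimizer')  # raises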

pytorch_optimizer/base/optimizer.py

Lines changed: 33 additions & 7 deletions
@@ -10,6 +10,7 @@
     BETAS,
     CLOSURE,
     DEFAULTS,
+    GROUP,
     HUTCHINSON_G,
     LOSS,
     OPTIMIZER_INSTANCE_OR_CLASS,
@@ -163,7 +164,10 @@ def apply_ams_bound(
         :param eps: float. epsilon.
         """
         if ams_bound:
-            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+            if torch.is_complex(max_exp_avg_sq):
+                max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)
+
+            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
             de_nom = max_exp_avg_sq.add(eps)
         else:
             de_nom = exp_avg_sq.add(eps)
@@ -195,7 +199,7 @@ def debias_beta(beta: float, step: int) -> float:
     def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
         r"""Apply AdamD variant.
 
-        :param adam_debias: bool. whether to apply AdamD.
+        :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
         :param step_size: float. step size.
         :param bias_correction1: float. bias_correction.
         """
@@ -247,16 +251,19 @@ def get_adanorm_gradient(
         r"""Get AdaNorm gradient.
 
         :param grad: torch.Tensor. gradient.
-        :param adanorm: bool. whether to apply AdaNorm.
+        :param adanorm: bool. whether to use the AdaNorm variant.
         :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
-        :param r: float. Optional[float]. momentum (ratio).
+        :param r: Optional[float]. EMA factor. between 0.9 ~ 0.99 is preferred.
         """
         if not adanorm or exp_grad_norm is None:
             return grad
 
+        if r is None:
+            r = 0.95
+
         grad_norm = torch.linalg.norm(grad)
 
-        exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)
+        exp_grad_norm.mul(r).add_(grad_norm, alpha=1.0 - r)
 
         return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad
 
@@ -371,8 +378,27 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
         self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')
 
     @abstractmethod
-    def reset(self) -> None:  # pragma: no cover
-        raise NotImplementedError
+    def init_group(self, group: GROUP, **kwargs) -> None:  # pragma: no cover
+        r"""Initialize the group of the optimizer and return is_complex."""
+        return
+
+    @staticmethod
+    def view_as_real(param, *state_and_grads) -> tuple:
+        r"""View imaginary tensors as real tensors."""
+        if torch.is_complex(param):
+            param = torch.view_as_real(param)
+            state_and_grads = tuple(
+                torch.view_as_real(s) if (s is not None and torch.is_complex(s)) else s if s is not None else None
+                for s in state_and_grads
+            )
+
+        return param, *state_and_grads
+
+    @staticmethod
+    def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
+        r"""Maximize the objective with respect to the params, instead of minimizing."""
+        if maximize:
+            grad.neg_()
 
     def step(self, closure: CLOSURE = None) -> LOSS:  # pragma: no cover
         raise NotImplementedError
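A short standalone sketch (plain PyTorch, outside the library, assumptions only) of what the two new static helpers above amount to: complex tensors are reinterpreted as real tensors with a trailing dimension of 2 so the usual element-wise update math applies, and `maximize` negates the gradient in place so a descent step maximizes the objective.

import torch

param = torch.randn(3, dtype=torch.cfloat)
grad = torch.randn(3, dtype=torch.cfloat)

# Complex (3,) tensors become real (3, 2) views; updates on the view write
# back into the original complex storage.
param_r = torch.view_as_real(param)
grad_r = torch.view_as_real(grad)
print(param_r.shape)  # torch.Size([3, 2])

# With maximize=True the gradient is negated in place before the update,
# turning gradient descent into gradient ascent.
maximize = True
if maximize:
    grad_r.neg_()

# A plain SGD-style step on the real view updates the complex parameter.
param_r.add_(grad_r, alpha=-1e-3)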

pytorch_optimizer/base/type.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
 LOSS = Optional[float]
 BETAS = Union[Tuple[float, float], Tuple[float, float, float], Tuple[None, float]]
 DEFAULTS = Dict
-PARAMETERS = Optional[Union[Iterable[Dict], Iterable[torch.Tensor]]]
+GROUP = Dict
+PARAMETERS = Optional[Union[Iterable[GROUP], Iterable[torch.Tensor]]]
 STATE = Dict
 OPTIMIZER = Type[Optimizer]
 OPTIMIZER_INSTANCE_OR_CLASS = Union[OPTIMIZER, Optimizer]
