Commit c9e38f6

Merge pull request #116 from kozistr/feature/optimizers
[Feature] Implement Ali-G optimizer
2 parents 26b8b19 + 50295aa commit c9e38f6

18 files changed (+288, -66 lines changed)

README.rst

Lines changed: 23 additions & 6 deletions
@@ -65,8 +65,25 @@ Also, you can load the optimizer via `torch.hub`
     opt = torch.hub.load('kozistr/pytorch_optimizer', 'adamp')
     optimizer = opt(model.parameters())
 
+If you want to build the optimizer with parameters & configs, there's `create_optimizer()` API.
 
-And you can check the supported optimizers & lr schedulers.
+::
+
+    from pytorch_optimizer import create_optimizer
+
+    optimizer = create_optimizer(
+        model,
+        'adamp',
+        lr=1e-2,
+        weight_decay=1e-3,
+        use_gc=True,
+        use_lookahead=True,
+    )
+
+Supported Optimizers
+--------------------
+
+You can check the supported optimizers & lr schedulers.
 
 ::
 
@@ -75,10 +92,6 @@ And you can check the supported optimizers & lr schedulers.
     supported_optimizers = get_supported_optimizers()
     supported_lr_schedulers = get_supported_lr_schedulers()
 
-
-Supported Optimizers
---------------------
-
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Optimizer | Description | Official Code | Paper |
 +==============+=================================================================================================+===================================================================================+===============================================================================================+
@@ -124,6 +137,8 @@ Supported Optimizers
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Lion | *Symbolic Discovery of Optimization Algorithms* | `github <https://github.com/google/automl/tree/master/lion>`__ | `https://arxiv.org/abs/2302.06675 <https://arxiv.org/abs/2302.06675>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Ali-G | *Adaptive Learning Rates for Interpolation with Gradients* | `github <https://github.com/oval-group/ali-g>`__ | `https://arxiv.org/abs/1906.05661 <https://arxiv.org/abs/1906.05661>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------
@@ -327,6 +342,8 @@ Citations
 
 `Lion <https://github.com/google/automl/tree/master/lion#citation>`__
 
+`Ali-G <https://github.com/oval-group/ali-g#adaptive-learning-rates-for-interpolation-with-gradients>`__
+
 Citation
 --------
 
@@ -338,7 +355,7 @@ Or you can get from "cite this repository" button.
     @software{Kim_pytorch_optimizer_Bunch_of_2022,
        author = {Kim, Hyeongchan},
        month = {1},
-       title = {{pytorch_optimizer: Bunch of optimizer implementations in PyTorch with clean-code, strict types}},
+       title = {{pytorch_optimizer: optimizer & lr scheduler implementations in PyTorch}},
        version = {1.0.0},
        year = {2022}
    }

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -272,3 +272,11 @@ Lion
 
 .. autoclass:: pytorch_optimizer.Lion
     :members:
+
+.. _AliG:
+
+AliG
+----
+
+.. autoclass:: pytorch_optimizer.AliG
+    :members:

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ target-version = "py39"
 "./tests/test_load_lr_schedulers.py" = ["D", "S101"]
 "./tests/test_lr_schedulers.py" = ["D"]
 "./tests/test_lr_scheduler_parameters.py" = ["D", "S101"]
+"./tests/test_create_optimizer.py" = ["D"]
 "./pytorch_optimizer/__init__.py" = ["F401"]
 "./pytorch_optimizer/lr_scheduler/__init__.py" = ["F401"]

pytorch_optimizer/__init__.py

Lines changed: 48 additions & 1 deletion
@@ -1,7 +1,9 @@
 # ruff: noqa
 from typing import Dict, List
 
-from pytorch_optimizer.base.types import OPTIMIZER, SCHEDULER
+from torch import nn
+
+from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
 from pytorch_optimizer.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler
 from pytorch_optimizer.lr_scheduler import (
     ConstantLR,
@@ -23,6 +25,7 @@
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.agc import agc
+from pytorch_optimizer.optimizer.alig import AliG
 from pytorch_optimizer.optimizer.apollo import Apollo
 from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptSGD
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
@@ -100,6 +103,7 @@
     Apollo,
     NovoGrad,
     Lion,
+    AliG,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
 
@@ -129,6 +133,49 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
     return OPTIMIZERS[optimizer]
 
 
+def create_optimizer(
+    model: nn.Module,
+    optimizer_name: str,
+    lr: float = 1e-3,
+    weight_decay: float = 0.0,
+    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
+    use_lookahead: bool = False,
+    **kwargs,
+):
+    r"""Build optimizer.
+
+    :param model: nn.Module. model.
+    :param optimizer_name: str. name of optimizer.
+    :param lr: float. learning rate.
+    :param weight_decay: float. weight decay.
+    :param wd_ban_list: List[str]. weight decay ban list by layer.
+    :param use_lookahead: bool. use lookahead.
+    """
+    optimizer_name = optimizer_name.lower()
+
+    if weight_decay > 0.0:
+        parameters = get_optimizer_parameters(model, weight_decay, wd_ban_list)
+    else:
+        parameters = model.parameters()
+
+    optimizer = load_optimizer(optimizer_name)
+
+    if optimizer_name == 'alig':
+        optimizer = optimizer(parameters, max_lr=lr, **kwargs)
+    else:
+        optimizer = optimizer(parameters, lr=lr, **kwargs)
+
+    if use_lookahead:
+        optimizer = Lookahead(
+            optimizer,
+            k=kwargs['k'] if 'k' in kwargs else 5,
+            alpha=kwargs['alpha'] if 'alpha' in kwargs else 0.5,
+            pullback_momentum=kwargs['pullback_momentum'] if 'pullback_momentum' in kwargs else 'none',
+        )
+
+    return optimizer
+
+
 def load_lr_scheduler(lr_scheduler: str) -> SCHEDULER:
     lr_scheduler: str = lr_scheduler.lower()
 
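As a minimal usage sketch of the new `create_optimizer()` helper (illustrative only; the toy `nn.Sequential` model below is hypothetical, not part of the diff): with `weight_decay > 0` the parameters are grouped via `get_optimizer_parameters()` so layers on `wd_ban_list` skip decay, `'alig'` is special-cased so that `lr` is forwarded as `max_lr`, and `use_lookahead=True` wraps the result in `Lookahead` (defaults `k=5`, `alpha=0.5`, `pullback_momentum='none'`).

::

    from torch import nn

    from pytorch_optimizer import create_optimizer

    # hypothetical toy model; any nn.Module works here
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

    # parameter groups built with weight decay, then wrapped in Lookahead
    optimizer = create_optimizer(model, 'adamp', lr=1e-2, weight_decay=1e-3, use_lookahead=True)

    # for AliG, lr is passed through as max_lr; extra kwargs go to the optimizer constructor
    alig = create_optimizer(model, 'alig', lr=1e-1, momentum=0.9)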

pytorch_optimizer/optimizer/adams.py

Lines changed: 1 addition & 6 deletions
@@ -39,12 +39,7 @@ def __init__(
 
         self.validate_parameters()
 
-        defaults: DEFAULTS = {
-            'lr': lr,
-            'betas': betas,
-            'weight_decay': weight_decay,
-            'eps': eps,
-        }
+        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps}
         super().__init__(params, defaults)
 
     def validate_parameters(self):

pytorch_optimizer/optimizer/alig.py

Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
+from typing import Callable, Optional
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+
+class AliG(Optimizer, BaseOptimizer):
+    r"""Adaptive Learning Rates for Interpolation with Gradients.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param max_lr: Optional[float]. max learning rate.
+    :param projection_fn: Callable. projection function to enforce constraints.
+    :param momentum: float. momentum.
+    :param adjusted_momentum: bool. if True, use pytorch-like momentum, instead of standard Nesterov momentum.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        max_lr: Optional[float] = None,
+        projection_fn: Optional[Callable] = None,
+        momentum: float = 0.0,
+        adjusted_momentum: bool = False,
+        eps: float = 1e-5,
+    ):
+        self.max_lr = max_lr
+        self.projection_fn = projection_fn
+        self.momentum = momentum
+        self.adjusted_momentum = adjusted_momentum
+        self.eps = eps
+
+        self.validate_parameters()
+
+        defaults: DEFAULTS = {'max_lr': max_lr, 'momentum': momentum}
+        super().__init__(params, defaults)
+
+        if self.projection_fn is not None:
+            self.projection_fn()
+
+    def validate_parameters(self):
+        self.validate_momentum(self.momentum)
+        self.validate_epsilon(self.eps)
+
+    @property
+    def __str__(self) -> str:
+        return 'AliG'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['momentum_buffer'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def compute_step_size(self, loss: float) -> float:
+        r"""Compute step_size."""
+        global_grad_norm: float = 0
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is not None:
+                    global_grad_norm += p.grad.norm().pow(2).item()
+
+        return loss / (global_grad_norm + self.eps)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        if closure is None:
+            raise ValueError('[-] AliG optimizer needs closure. (eg. `optimizer.step(lambda: float(loss))`).')
+
+        loss = closure()
+
+        un_clipped_step_size: float = self.compute_step_size(loss)
+
+        for group in self.param_groups:
+            step_size = group['step_size'] = (
+                min(un_clipped_step_size, group['max_lr']) if group['max_lr'] is not None else un_clipped_step_size
+            )
+            momentum = group['momentum']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(self.__str__)
+
+                state = self.state[p]
+                if len(state) == 0 and momentum > 0.0:
+                    state['momentum_buffer'] = torch.zeros_like(p)
+
+                p.add_(grad, alpha=-step_size)
+
+                if momentum > 0.0:
+                    buffer = state['momentum_buffer']
+
+                    if self.adjusted_momentum:
+                        buffer.mul_(momentum).sub_(grad)
+                        p.add_(buffer, alpha=step_size * momentum)
+                    else:
+                        buffer.mul_(momentum).add_(grad, alpha=-step_size)
+                        p.add_(buffer, alpha=momentum)
+
+        if self.projection_fn is not None:
+            self.projection_fn()
+
+        return loss
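`AliG.step()` requires a closure that returns the current loss value: the unclipped step size is `loss / (sum of squared gradient norms + eps)`, clipped to `max_lr` per parameter group. A minimal training-step sketch (illustrative only; the model, criterion, and batch below are hypothetical, not part of the diff):

::

    import torch
    from torch import nn

    from pytorch_optimizer import AliG

    # hypothetical model, criterion and batch for illustration
    model = nn.Linear(10, 1)
    criterion = nn.MSELoss()
    x, y = torch.randn(32, 10), torch.randn(32, 1)

    optimizer = AliG(model.parameters(), max_lr=1e-1, momentum=0.9)

    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()

    # gradients are already populated; the closure only needs to return the loss as a float
    optimizer.step(lambda: float(loss))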

pytorch_optimizer/optimizer/gsam.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def loss_fn(predictions, targets):
       lr_scheduler.step()
       optimizer.update_rho_t()
 
-    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param base_optimizer: Optimizer. base optimizer.
     :param model: nn.Module. model.
     :param alpha: float. rho alpha.

pytorch_optimizer/optimizer/lion.py

Lines changed: 1 addition & 5 deletions
@@ -31,11 +31,7 @@ def __init__(
 
         self.validate_parameters()
 
-        defaults: DEFAULTS = {
-            'lr': lr,
-            'betas': betas,
-            'weight_decay': weight_decay,
-        }
+        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}
         super().__init__(params, defaults)
 
     def validate_parameters(self):

pytorch_optimizer/optimizer/madgrad.py

Lines changed: 1 addition & 6 deletions
@@ -42,12 +42,7 @@ def __init__(
 
         self.validate_parameters()
 
-        defaults: DEFAULTS = {
-            'lr': lr,
-            'weight_decay': weight_decay,
-            'momentum': momentum,
-            'eps': eps,
-        }
+        defaults: DEFAULTS = {'lr': lr, 'weight_decay': weight_decay, 'momentum': momentum, 'eps': eps}
         super().__init__(params, defaults)
 
     def validate_parameters(self):

pytorch_optimizer/optimizer/novograd.py

Lines changed: 1 addition & 6 deletions
@@ -39,12 +39,7 @@ def __init__(
 
         self.validate_parameters()
 
-        defaults: DEFAULTS = {
-            'lr': lr,
-            'betas': betas,
-            'weight_decay': weight_decay,
-            'eps': eps,
-        }
+        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps}
         super().__init__(params, defaults)
 
     def validate_parameters(self):
