Skip to content

Commit 1054960

Browse files
authored
Merge pull request #263 from kozistr/feature/trac-optimizer
[Feature] Implement TRAC optimizer
2 parents d00136f + e94290f commit 1054960

File tree

15 files changed

+483
-169
lines changed

15 files changed

+483
-169
lines changed

README.md

Lines changed: 79 additions & 78 deletions
Large diffs are not rendered by default.

docs/changelogs/v3.1.1.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
## Change Log
2+
3+
### Feature
4+
5+
* Implement `TRAC` optimizer. (#263)
6+
* [Fast TRAC: A Parameter-Free Optimizer for Lifelong Reinforcement Learning](https://arxiv.org/abs/2405.16642)
7+
* Support `AdamW` optimizer via `create_optimizer()`. (#263)
8+
9+
### Bug
10+
11+
* Fix `create_optimizer()` to handle optimizers that take the `model` itself instead of its parameters. (#263)

docs/index.md

Lines changed: 79 additions & 78 deletions
Large diffs are not rendered by default.

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,10 @@
328328
:docstring:
329329
:members:
330330

331+
::: pytorch_optimizer.TRAC
332+
:docstring:
333+
:members:
334+
331335
::: pytorch_optimizer.WSAM
332336
:docstring:
333337
:members:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ keywords = [
1717
"GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG",
1818
"Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21",
1919
"RotoGrad", "SAM", "ScheduleFreeSGD", "ScheduleFreeAdamW", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD",
20-
"SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
20+
"SM3", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
2121
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
2222
"bitsandbytes", "WSD", "QGaLore",
2323
]

pytorch_optimizer/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch.cuda
66
from torch import nn
7+
from torch.optim import AdamW
78

89
from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
910
from pytorch_optimizer.loss.bi_tempered import BinaryBiTemperedLogisticLoss, BiTemperedLogisticLoss
@@ -115,6 +116,7 @@
115116
from pytorch_optimizer.optimizer.srmm import SRMM
116117
from pytorch_optimizer.optimizer.swats import SWATS
117118
from pytorch_optimizer.optimizer.tiger import Tiger
119+
from pytorch_optimizer.optimizer.trac import TRAC
118120
from pytorch_optimizer.optimizer.utils import (
119121
clip_grad_norm,
120122
disable_running_stats,
@@ -131,6 +133,7 @@
131133
HAS_Q_GALORE: bool = find_spec('q-galore-torch') is not None
132134

133135
OPTIMIZER_LIST: List[OPTIMIZER] = [
136+
AdamW,
134137
AdaBelief,
135138
AdaBound,
136139
PID,
@@ -350,6 +353,8 @@ def create_optimizer(
350353

351354
if optimizer_name == 'alig':
352355
optimizer = optimizer(parameters, max_lr=lr, **kwargs)
356+
elif optimizer_name in {'lomo', 'adalomo', 'adammini'}:
357+
optimizer = optimizer(model, lr=lr, **kwargs)
353358
else:
354359
optimizer = optimizer(parameters, lr=lr, **kwargs)
355360

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
k: int = 5,
2323
alpha: float = 0.5,
2424
pullback_momentum: str = 'none',
25-
):
25+
) -> None:
2626
self.validate_positive(k, 'k')
2727
self.validate_range(alpha, 'alpha', 0.0, 1.0)
2828
self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
from typing import Callable, Dict, List, Tuple
2+
3+
import torch
4+
from torch import nn
5+
6+
from pytorch_optimizer.base.optimizer import BaseOptimizer
7+
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER
8+
9+
10+
def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
    r"""Evaluate a polynomial with the Horner scheme.

    taken from https://discuss.pytorch.org/t/polynomial-evaluation-by-horner-rule/67124

    :param x: torch.Tensor. variable. It must carry a leading dimension (callers pass ``z.unsqueeze(0)``); that
        dimension is indexed away from the result.
    :param coef: torch.Tensor. coefficients of the polynomial, highest degree first. Must not be empty.
    :return: torch.Tensor. the polynomial evaluated at ``x[0]``, element-wise.
    """
    # Seed the accumulator already broadcast to the shape of `x`. With a bare `coef[0]` a constant polynomial
    # (a single coefficient) would stay 0-dimensional and the final `result[0]` would raise an IndexError.
    result = torch.zeros_like(x) + coef[0]

    for c in coef[1:]:
        result = (result * x) + c

    return result[0]
24+
25+
26+
class ERF1994(nn.Module):
    r"""Approximate the complex error function through the Faddeeva function.

    Every coefficient is computed once in the constructor and kept as a plain tensor attribute. The tensors
    needed at call time are moved to the input's device lazily inside :meth:`w_algorithm`.

    :param num_coefs: int. The number of polynomial coefficients to use in the approximation.
    """

    def __init__(self, num_coefs: int = 128) -> None:
        super().__init__()

        self.n: int = num_coefs

        # imaginary unit, stored as a 0-dim complex tensor
        self.i: torch.Tensor = torch.complex(torch.tensor(0.0), torch.tensor(1.0))

        self.m = 2 * self.n
        self.m2 = 2 * self.m

        # sample points used to build the series coefficients
        self.k = torch.linspace(-self.m + 1, self.m - 1, self.m2 - 1)
        self.l = torch.sqrt(self.n / torch.sqrt(torch.tensor(2.0)))
        self.theta = self.k * torch.pi / self.m
        self.t = self.l * torch.tan(self.theta / 2.0)
        self.f = torch.exp(-self.t ** 2) * (self.l ** 2 + self.t ** 2)  # fmt: skip

        # real part of the normalised spectrum; only entries 1..n are kept, in reversed order
        spectrum = torch.fft.fft(torch.fft.fftshift(self.f)).real / self.m2
        self.a = torch.flipud(spectrum[1:self.n + 1])  # fmt: skip

    def w_algorithm(self, z: torch.Tensor) -> torch.Tensor:
        r"""Compute the Faddeeva function of a complex number.

        Note that this re-binds ``self.l``, ``self.i`` and ``self.a`` to copies living on the device of ``z``.

        :param z: torch.Tensor. A tensor of complex numbers.
        """
        device = z.device
        self.l, self.i, self.a = self.l.to(device), self.i.to(device), self.a.to(device)

        rotated = self.i * z
        numerator = self.l + rotated
        denominator = self.l - rotated

        # the series is a polynomial in the ratio (l + iz) / (l - iz)
        series = polyval((numerator / denominator).unsqueeze(0), self.a)
        inv_sqrt_pi = 1.0 / torch.sqrt(torch.tensor(torch.pi))

        return 2 * series / denominator.pow(2) + inv_sqrt_pi / denominator

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        r"""Compute the error function of a complex number.

        The input is folded onto non-negative real and imaginary parts, evaluated there, and the signs of the
        original components are re-applied to the real and imaginary parts of the result.

        :param z: torch.Tensor. A tensor of complex numbers.
        """
        real_sign, imag_sign = torch.sign(z.real), torch.sign(z.imag)

        folded = torch.complex(torch.abs(z.real), torch.abs(z.imag))
        w = self.w_algorithm(folded * self.i)
        result = -torch.exp(torch.log(w) - folded ** 2) + 1  # fmt: skip

        return torch.complex(result.real * real_sign, result.imag * imag_sign)
74+
75+
76+
class TRAC(BaseOptimizer):
    r"""A Parameter-Free Optimizer for Lifelong Reinforcement Learning.

    Wraps a base optimizer. After each base step the displacement of every parameter from its stored reference
    point is rescaled by a scale computed from the running ``sigma`` / ``variance`` statistics (one entry per beta).

    Example:
    -------
        Here's an example::

            model = YourModel()
            optimizer = TRAC(AdamW(model.parameters()))

            for input, output in data:
                optimizer.zero_grad()

                loss = loss_fn(model(input), output)
                loss.backward()

                optimizer.step()

    :param optimizer: Optimizer. base optimizer.
    :param betas: List[float]. list of beta values.
    :param num_coefs: int. the number of polynomial coefficients to use in the approximation.
    :param s_prev: float. initial scale value.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        betas: List[float] = (0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999),
        num_coefs: int = 128,
        s_prev: float = 1e-8,
        eps: float = 1e-8,
    ) -> None:
        self.validate_positive(num_coefs, 'num_coefs')
        self.validate_non_negative(s_prev, 's_prev')
        self.validate_non_negative(eps, 'eps')

        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.erf = ERF1994(num_coefs=num_coefs)
        self.betas = betas
        self.s_prev = s_prev
        self.eps = eps

        # constant factor applied to every freshly computed scale term
        self.f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))

        self.optimizer = optimizer
        self.defaults: DEFAULTS = optimizer.defaults

    def __str__(self) -> str:
        return 'TRAC'

    @property
    def param_groups(self):
        r"""Parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    @property
    def state(self):
        r"""State of the wrapped optimizer; the TRAC statistics are stored under its ``'trac'`` key."""
        return self.optimizer.state

    @torch.no_grad()
    def _init_state(self) -> None:
        r"""Create the ``'trac'`` state entry and snapshot every parameter as its reference point.

        All statistics are allocated on the device of the first parameter of the first group.
        """
        device = self.param_groups[0]['params'][0].device
        num_betas = len(self.betas)

        trac_state = {
            'betas': torch.tensor(self.betas, device=device),
            's': torch.zeros(num_betas, device=device),
            'variance': torch.zeros(num_betas, device=device),
            'sigma': torch.full((num_betas,), 1e-8, device=device),
            'step': 0,
        }

        # the reference points share the dict with the statistics and are keyed by the parameter itself
        for group in self.param_groups:
            for p in group['params']:
                trac_state[p] = p.clone()

        self.state['trac'] = trac_state

    @torch.no_grad()
    def reset(self) -> None:
        r"""Re-initialize the TRAC statistics and take the current parameters as the new reference point."""
        self._init_state()

    @torch.no_grad()
    def zero_grad(self) -> None:
        r"""Clear the gradients through the wrapped optimizer (gradients are set to ``None``)."""
        self.optimizer.zero_grad(set_to_none=True)

    @torch.no_grad()
    def erf_imag(self, x: torch.Tensor) -> torch.Tensor:
        r"""Return the imaginary part of ``erf(i * x)`` for a real tensor.

        :param x: torch.Tensor. real-valued input; non-floating inputs are cast to float32 first.
        """
        if not torch.is_floating_point(x):
            x = x.to(torch.float32)

        ix = torch.complex(torch.zeros_like(x), x)

        return self.erf(ix).imag

    @torch.no_grad()
    def backup_params_and_grads(self) -> Tuple[Dict, Dict]:
        r"""Clone the current value and gradient of every parameter.

        :return: Tuple[Dict, Dict]. two dicts keyed by parameter: the cloned values, and the cloned gradients
            (``None`` for parameters without a gradient).
        """
        updates, grads = {}, {}

        for group in self.param_groups:
            for p in group['params']:
                updates[p] = p.clone()
                grads[p] = p.grad.clone() if p.grad is not None else None

        return updates, grads

    @torch.no_grad()
    def trac_step(self, updates: Dict, grads: Dict) -> None:
        r"""Update the TRAC statistics and write the rescaled parameters back.

        The tensors in ``updates`` are modified in place.

        :param updates: Dict. parameter values as returned by :meth:`backup_params_and_grads`.
        :param grads: Dict. gradients as returned by :meth:`backup_params_and_grads`; parameters whose entry is
            ``None`` are skipped.
        """
        trac_state = self.state['trac']
        trac_state['step'] += 1

        deltas = {}

        device = self.param_groups[0]['params'][0].device

        # `s` is only rewritten after this loop, so the normalizer is loop-invariant and computed once.
        # `torch.sum` returns a fresh tensor, hence the in-place add does not touch the stored `s`.
        denominator = torch.sum(trac_state['s']).add_(self.eps)

        h = torch.zeros((1,), device=device)
        for group in self.param_groups:
            for p in group['params']:
                if grads[p] is None:
                    continue

                theta_ref = trac_state[p]
                update = updates[p]

                deltas[p] = (update - theta_ref) / denominator

                # NOTE(review): `step()` takes the backup after the base optimizer has run, so `p - update` is
                # zero on that path (see the TODO in `step()`); kept as-is to preserve the current behaviour.
                update.neg_().add_(p)

                grad, delta = grads[p], deltas[p]

                # accumulate <delta, grad> over all parameters
                product = torch.dot(delta.flatten(), grad.flatten())
                h.add_(product)

                delta.add_(update)

        s = trac_state['s']
        betas = trac_state['betas']
        variance = trac_state['variance']
        sigma = trac_state['sigma']

        # discounted running statistics, one entry per beta (`h` broadcasts over them)
        variance.mul_(betas.pow(2)).add_(h.pow(2))
        sigma.mul_(betas).sub_(h)

        s_term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps))
        s_term.mul_(self.f_term)
        s.copy_(s_term)

        # the applied scale is never negative; clamping keeps it a tensor and avoids comparing on the host
        scale = torch.clamp(torch.sum(s), min=0.0)

        for group in self.param_groups:
            for p in group['params']:
                if grads[p] is None:
                    continue

                delta = deltas[p]
                delta.mul_(scale).add_(trac_state[p])

                p.copy_(delta)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Run one step of the wrapped optimizer followed by the TRAC rescaling.

        :param closure: CLOSURE. forwarded to the wrapped optimizer's ``step``.
        :return: LOSS. whatever the wrapped optimizer's ``step`` returned.
        """
        # TODO: backup is first to get the delta of param and grad, but it does not work.
        with torch.enable_grad():
            loss = self.optimizer.step(closure)

        updates, grads = self.backup_params_and_grads()

        if 'trac' not in self.state:
            # first call: the parameters have not changed since the backup above, so snapshotting them here
            # stores the same reference values the backup holds
            self._init_state()

        self.trac_step(updates, grads)

        return loss

tests/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
'wsam',
8686
'pcgrad',
8787
'lookahead',
88+
'trac',
8889
]
8990

9091
SPARSE_OPTIMIZERS: List[str] = ['madgrad', 'dadaptadagrad', 'sm3']

tests/test_create_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def test_create_optimizer():
99

1010
create_optimizer(model, 'adamp', lr=1e-2, weight_decay=1e-3, use_gc=True, use_lookahead=True)
1111
create_optimizer(model, 'alig', lr=1e-2, use_lookahead=True)
12+
create_optimizer(model, 'adalomo', lr=1e-2, use_lookahead=False)
1213

1314

1415
def test_bnb_optimizer():

0 commit comments

Comments
 (0)