Skip to content

Commit 4f4d359

Browse files
committed
update: TRAC optimizer
1 parent fe3b5d9 commit 4f4d359

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

pytorch_optimizer/optimizer/trac.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from collections import defaultdict
21
from typing import Callable, Dict, List, Tuple
32

43
import torch
54
from torch import nn
65

76
from pytorch_optimizer.base.optimizer import BaseOptimizer
8-
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE
7+
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER
98

109

1110
def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
@@ -119,8 +118,9 @@ def __init__(
119118
self.s_prev = s_prev
120119
self.eps = eps
121120

121+
self.f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))
122+
122123
self.optimizer = optimizer
123-
self.state: STATE = defaultdict(dict)
124124
self.defaults: DEFAULTS = optimizer.defaults
125125

126126
def __str__(self) -> str:
@@ -130,6 +130,10 @@ def __str__(self) -> str:
130130
def param_groups(self):
131131
return self.optimizer.param_groups
132132

133+
@property
134+
def state(self):
135+
return self.optimizer.state
136+
133137
@torch.no_grad()
134138
def reset(self):
135139
device = self.param_groups[0]['params'][0].device
@@ -172,7 +176,7 @@ def backup_params_and_grads(self) -> Tuple[Dict, Dict]:
172176

173177
@torch.no_grad()
174178
def trac_step(self, updates: Dict, grads: Dict) -> None:
175-
self.state['step'] += 1
179+
self.state['trac']['step'] += 1
176180

177181
deltas = {}
178182

@@ -181,13 +185,13 @@ def trac_step(self, updates: Dict, grads: Dict) -> None:
181185
h = torch.zeros((1,), device=device)
182186
for group in self.param_groups:
183187
for p in group['params']:
184-
if p.grad is None:
188+
if grads[p] is None:
185189
continue
186190

187-
theta_ref = self.state[p]
191+
theta_ref = self.state['trac'][p]
188192
update = updates[p]
189193

190-
deltas[p] = (update - theta_ref) / (torch.sum(self.state['s']) + self.eps)
194+
deltas[p] = (update - theta_ref) / torch.sum(self.state['trac']['s']).add_(self.eps)
191195
update.neg_().add_(p)
192196

193197
grad, delta = grads[p], deltas[p]
@@ -197,36 +201,42 @@ def trac_step(self, updates: Dict, grads: Dict) -> None:
197201

198202
delta.add_(update)
199203

200-
s = self.state['s']
201-
betas = self.state['betas']
202-
variance = self.state['variance']
203-
sigma = self.state['sigma']
204+
s = self.state['trac']['s']
205+
betas = self.state['trac']['betas']
206+
variance = self.state['trac']['variance']
207+
sigma = self.state['trac']['sigma']
204208

205209
variance.mul_(betas.pow(2)).add_(h.pow(2))
206210
sigma.mul_(betas).sub_(h)
207211

208-
f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))
209-
s_term = self.erf_imag(sigma / (torch.sqrt(torch.tensor(2.0)) * variance.sqrt() + self.eps))
210-
s.copy_(f_term * s_term)
212+
s_term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps))
213+
s_term.mul_(self.f_term)
214+
s.copy_(s_term)
215+
216+
scale = max(torch.sum(s), 0.0)
211217

212218
for group in self.param_groups:
213219
for p in group['params']:
214220
if grads[p] is None:
215221
continue
216222

217-
p.copy_(self.state[p] + deltas[p] * max(torch.sum(s), 0.0))
223+
delta = deltas[p]
224+
delta.mul_(scale).add_(self.state['trac'][p])
225+
226+
p.copy_(delta)
218227

219228
@torch.no_grad()
220229
def step(self, closure: CLOSURE = None) -> LOSS:
230+
# TODO: backup is first to get the delta of param and grad, but it does not work.
221231
with torch.enable_grad():
222232
loss = self.optimizer.step(closure)
223233

224234
updates, grads = self.backup_params_and_grads()
225235

226-
if len(self.state) == 0:
227-
device = updates[next(iter(updates.keys()))].device
236+
if 'trac' not in self.state:
237+
device = self.param_groups[0]['params'][0].device
228238

229-
self.state = {
239+
self.state['trac'] = {
230240
'betas': torch.tensor(self.betas, device=device),
231241
's': torch.zeros(len(self.betas), device=device),
232242
'variance': torch.zeros(len(self.betas), device=device),
@@ -236,7 +246,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
236246

237247
for group in self.param_groups:
238248
for p in group['params']:
239-
self.state[p] = updates[p].clone()
249+
self.state['trac'][p] = updates[p].clone()
240250

241251
self.trac_step(updates, grads)
242252

tests/test_optimizers.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -676,18 +676,16 @@ def test_trac_optimizer(environment):
676676
optimizer = TRAC(load_optimizer('adamw')(model.parameters(), lr=1e0))
677677

678678
init_loss, loss = np.inf, np.inf
679-
for _ in range(5):
680-
optimizer.zero_grad()
681-
682-
y_pred = model(x_data)
683-
loss = loss_fn(y_pred, y_data)
679+
for _ in range(3):
680+
loss = loss_fn(model(x_data), y_data)
684681

685682
if init_loss == np.inf:
686683
init_loss = loss
687684

688685
loss.backward()
689686

690687
optimizer.step()
688+
optimizer.zero_grad()
691689

692690
assert tensor_to_numpy(init_loss) > 2.0 * tensor_to_numpy(loss)
693691

0 commit comments

Comments (0)