Commit 6da1cd2

feat(optimizer): add TGLaProp
1 parent 010020b commit 6da1cd2

1 file changed, +134 -41 lines changed

truegrad/optim.py

Lines changed: 134 additions & 41 deletions
@@ -1,19 +1,51 @@
-from typing import Tuple, Union
+import enum
+import warnings
+from typing import Tuple, Union, List, Dict, Any, Optional
 
 import torch
+from torch import Tensor
+from torch.nn import Parameter
 
 
-class TGAdamW(torch.optim.Optimizer):
+class BaseOptimizer(enum.Enum, str):
+    adam: str = "adam"
+    laprop: str = "laprop"
+
+
+def ema_(base: Tensor, update: Tensor, beta: float, step: int = 0):
+    base.mul_(beta).add_(update, alpha=1 - beta)
+    if not step:
+        return base
+    return base / (1 - beta ** step)
+
+
+def stable_sqrt(base: Tensor, eps: float):
+    return base.sqrt().clamp(min=eps)
+
+
+def div_ema(base: Tensor, eps: float, base_sq: Tensor, update_sq: Tensor, beta_sq: float, step: int = 0):
+    return base / stable_sqrt(ema_(base_sq, update_sq, beta_sq, step), eps)
+
+
+class TrueGrad(torch.optim.Optimizer):
+    true_statistics: List[str] = []
+    base_statistics: List[str] = []
+    shared_statistics: List[str] = []
+
     def __init__(self, params, lr: float = 1e-3,
-                 betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
+                 betas: List[float] = (),
                  eps: float = 1e-12,
                  weight_decay: float = 1e-2,
                  graft: bool = True,
                  decay_to_init: bool = False,
-                 default_to_adam: bool = False):
+                 default_to_baseline: bool = False):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                        decay_to_init=decay_to_init, default_to_adam=default_to_adam)
-        super(TGAdamW, self).__init__(params, defaults)
+                        decay_to_init=decay_to_init, default_to_baseline=default_to_baseline)
+        super(TrueGrad, self).__init__(params, defaults)
+
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], **kwargs: Tensor
+               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+        raise NotImplementedError
 
     @torch.no_grad()
     def step(self, closure=None):
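The helpers introduced above carry the moment bookkeeping that both optimizers share: `ema_` updates a statistics buffer in place and, when a step count is passed, returns a bias-corrected copy rather than the raw buffer, while `div_ema` divides a tensor by the clamped square root of such a debiased second moment. Below is a minimal standalone sketch of how they compose, with the helper bodies copied from this diff; the toy tensors and the printed check are illustrative only, not part of the commit.

import torch
from torch import Tensor

# Helper bodies copied from truegrad/optim.py as added in this commit.
def ema_(base: Tensor, update: Tensor, beta: float, step: int = 0):
    base.mul_(beta).add_(update, alpha=1 - beta)
    if not step:
        return base  # raw in-place EMA, no bias correction
    return base / (1 - beta ** step)  # debiased copy; `base` keeps the raw EMA

def stable_sqrt(base: Tensor, eps: float):
    return base.sqrt().clamp(min=eps)

def div_ema(base: Tensor, eps: float, base_sq: Tensor, update_sq: Tensor, beta_sq: float, step: int = 0):
    return base / stable_sqrt(ema_(base_sq, update_sq, beta_sq, step), eps)

# Illustrative first step of an AdamW-style update (toy values).
grad = torch.tensor([0.5, -2.0])
exp_avg = torch.zeros(2)     # first moment (momentum buffer)
exp_avg_sq = torch.zeros(2)  # second moment of the squared gradient

ema_(exp_avg, grad, 0.9)  # exp_avg is now 0.1 * grad
update = div_ema(exp_avg, 1e-12, exp_avg_sq, grad.square(), 0.999, step=1)
print(update)  # 0.1 * sign(grad); the 1 / (1 - beta1 ** step) factor is folded into alpha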
@@ -23,37 +55,30 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()
         for group in self.param_groups:
-            if len(group["betas"]) == 2:
-                beta1, beta2 = group["betas"]
-                beta3 = beta2
-            else:
-                beta1, beta2, beta3 = group['betas']
-
             for p in group['params']:
                 if p.grad is None:
                     continue
-                do_adam = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
-                if not group["default_to_adam"] and do_adam:
+                do_baseline = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
+                if not group["default_to_baseline"] and do_baseline:
                     raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `sum_grad_squared` attribute. "
                                      f"Make sure to use backpack.")
 
                 state = self.state[p]
 
                 if len(state) == 0:
-                    state['step'] = torch.tensor(0.)
-                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if not do_adam:
-                        state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if do_adam or group["graft"]:
-                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state['step'] = Tensor(0.)
+                    for s in self.shared_statistics:
+                        state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if not do_baseline:
+                        for s in self.true_statistics:
+                            state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if do_baseline or group["graft"]:
+                        for s in self.base_statistics:
+                            state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     if group["decay_to_init"]:
                         state["init"] = torch.clone(p.detach())
 
-                exp_avg = state['exp_avg']
-                exp_avg_true_sq = state['exp_avg_true_sq']
                 step_t = state['step']
-
-                # update step
                 step_t += 1
 
                 # Perform stepweight decay
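State initialization is now driven by the `*_statistics` class attributes rather than hard-coded keys: which buffers a parameter receives depends on whether TrueGrad statistics are available for it (`do_baseline`) and whether grafting is enabled. A purely illustrative enumeration of the three combinations, using the statistics lists that TGLaProp declares further down in this diff:

# Mirrors the three `if` branches of the state-initialization block above.
true_statistics = ["exp_avg_true", "exp_avg_true_sq"]  # TGLaProp.true_statistics
base_statistics = ["exp_avg", "exp_avg_sq"]            # TGLaProp.base_statistics
shared_statistics = []                                 # inherited TrueGrad default

for do_baseline, graft in [(False, True), (False, False), (True, True)]:
    buffers = list(shared_statistics)
    if not do_baseline:
        buffers += true_statistics
    if do_baseline or graft:
        buffers += base_statistics
    print(f"do_baseline={do_baseline}, graft={graft}: {buffers}")
# do_baseline=False, graft=True:  all four buffers (TrueGrad update grafted onto the baseline)
# do_baseline=False, graft=False: only the exp_avg_true* buffers
# do_baseline=True,  graft=True:  only the exp_avg* baseline buffers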
@@ -63,26 +88,94 @@ def step(self, closure=None):
                 else:
                     p.mul_(1 - decay)
 
-                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
-
                 step = step_t.item()
-                alpha = -group['lr'] / (1 - beta1 ** step)
-
-                if not do_adam:
-                    exp_avg_true_sq.mul_(beta3).add_(p.sum_grad_squared, alpha=1 - beta3)
-                    p.sum_grad_squared = None
-                    denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
-                    update = exp_avg / denom
 
-                if group["graft"] or do_adam:
-                    exp_avg_sq = state['exp_avg_sq']
-                    exp_avg_sq.mul_(beta2).add_(p.grad.square(), alpha=1 - beta2)
-                    adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps'])
+                base_update, update, alpha = self._inner(step, p,
+                                                         **{k: state[k] for k in self.shared_statistics},
+                                                         **{k: state[k] for k in self.base_statistics},
+                                                         **{k: state[k] for k in self.true_statistics})
 
-                if group["graft"] and not do_adam:
-                    alpha = alpha * adam_update.norm() / update.norm().add_(group['eps'])
-                elif do_adam:
-                    update = adam_update
+                if group["graft"] and not do_baseline:
+                    alpha = alpha * base_update.norm() / update.norm().add_(group['eps'])
+                elif do_baseline:
+                    update = base_update
 
                 p.add_(update, alpha=alpha)
         return loss
+
+
+class TGAdamW(TrueGrad):
+    true_statistics: List[str] = ["exp_avg_true_sq"]
+    base_statistics: List[str] = ["exp_avg_sq"]
+    shared_statistics: List[str] = ["exp_avg"]
+
+    def __init__(self, params, lr: float = 1e-3,
+                 betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
+                 eps: float = 1e-12,
+                 weight_decay: float = 1e-2,
+                 graft: bool = True,
+                 decay_to_init: bool = False,
+                 default_to_adam: bool = None,
+                 default_to_baseline: bool = None):
+        if default_to_baseline is None:
+            default_to_baseline = default_to_adam
+        elif default_to_adam is not None:
+            raise ValueError("Can't set both default_to_baseline and default_to_adam, as both map to the same argument")
+        if default_to_adam is not None:
+            warnings.warn("default_to_adam is deprecated and will be replaced by default_to_baseline in April 2023")
+        if default_to_baseline is None:
+            default_to_baseline = False
+        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
+                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline)
+
+    def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any], exp_avg: Tensor,
+               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
+               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+        if len(group["betas"]) == 2:
+            (beta1, beta2), (_, beta3) = group["betas"], group["betas"]
+        else:
+            beta1, beta2, beta3 = group['betas']
+
+        update, base_update, eps = None, None, group["eps"]
+        ema_(exp_avg, p.grad, beta1)
+        if exp_avg_true_sq is not None:
+            update = div_ema(exp_avg, group["eps"], exp_avg_true_sq, p.sum_grad_squared, beta3, step)
+        if exp_avg_sq is not None:
+            base_update = div_ema(exp_avg, group["eps"], exp_avg_sq, p.grad.square(), beta2, step)
+
+        return base_update, update, -group['lr'] / (1 - beta1 ** step)
+
+
+class TGLaProp(TrueGrad):
+    true_statistics: List[str] = ["exp_avg_true", "exp_avg_true_sq"]
+    base_statistics: List[str] = ["exp_avg", "exp_avg_sq"]
+
+    def __init__(self, params, lr: float = 1e-3,
+                 betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99),
+                 eps: float = 1e-12,
+                 weight_decay: float = 1e-2,
+                 graft: bool = True,
+                 decay_to_init: bool = False,
+                 default_to_baseline: bool = False):
+        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
+                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline)
+
+    def _inner(self, step: int, p: Parameter, do_baseline: bool, group: Dict[str, Any],
+               exp_avg: Optional[Tensor] = None, exp_avg_sq: Optional[Tensor] = None,
+               exp_avg_true: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
+               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+        if len(group["betas"]) == 2:
+            (beta1, beta2), (beta3, beta4) = group["betas"], group["betas"]
+        else:
+            beta1, beta2, beta3, beta4 = group['betas']
+
+        update, base_update, alpha, eps = None, None, 1, group["eps"]
+        if exp_avg_true_sq is not None:
+            update = ema_(exp_avg_true, div_ema(p.grad, eps, exp_avg_true_sq, p.sum_grad_squared, beta4, step), beta3)
+            alpha = -group['lr'] / (1 - beta3 ** step)
+
+        if exp_avg_sq is not None:
+            base_update = ema_(exp_avg, div_ema(p.grad, eps, exp_avg_sq, p.grad.square(), beta2, step), beta1)
+            alpha = -group['lr'] / (1 - beta1 ** step)  # if grafting, beta3 issues are "grafted" away
+
+        return base_update, update, alpha
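For reference, the two `_inner` implementations differ mainly in where the momentum average is taken. A sketch in the notation of the code, with $g_t$ the gradient, $s_t$ the `sum_grad_squared` statistic attached to the parameter (via backpack, per the error message in `step`), and the $\beta$ indices following the `betas` tuples:

TGAdamW (momentum of the raw gradient, then normalization by the debiased second moment):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
    v_t = \beta_3 v_{t-1} + (1 - \beta_3)\, s_t
    u_t = m_t / \max\!\left(\sqrt{v_t / (1 - \beta_3^t)},\ \epsilon\right)

TGLaProp (normalization first, then momentum of the normalized gradient):

    n_t = \beta_4 n_{t-1} + (1 - \beta_4)\, s_t
    a_t = \beta_3 a_{t-1} + (1 - \beta_3)\, g_t / \max\!\left(\sqrt{n_t / (1 - \beta_4^t)},\ \epsilon\right)
    u_t = a_t

In both cases the parameter step is $p \leftarrow p + \alpha u_t$ with $\alpha = -\mathrm{lr} / (1 - \beta^t)$, where $\beta$ is $\beta_1$ for TGAdamW (and for TGLaProp whenever the baseline statistics are present) and $\beta_3$ for a pure TrueGrad TGLaProp step. With `graft=True`, `step` additionally rescales $\alpha$ by $\lVert u_{\text{base}} \rVert / (\lVert u_{\text{true}} \rVert + \epsilon)$, keeping the TrueGrad direction but borrowing the step length from the baseline update built from $g_t^2$ instead of $s_t$.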
