Commit c39bcc9
feat(optim): WeightDecayChain
1 parent fa62fbf
2 files changed (+147, -87 lines)

setup.py (1 addition, 1 deletion)

@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='3.1.1',
+    version='4.0.0',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),
truegrad/optim.py (146 additions, 86 deletions)

@@ -1,3 +1,4 @@
+import functools
 import warnings
 from typing import Tuple, Union, List, Dict, Any, Optional

@@ -6,6 +7,94 @@
 from torch.nn import Parameter


+class WeightDecayBase:
+    def __init__(self):
+        pass
+
+    def __call__(self, mod: torch.optim.Optimizer, p: torch.Tensor, idx: int):
+        return p
+
+
+class WeightDecayChain:
+    def __init__(self, *operands: WeightDecayBase):
+        self.operands = operands
+
+    def __call__(self, mod: torch.optim.Optimizer):
+        idx = 0
+        for group in mod.param_groups:
+            for p in group["params"]:
+                p.data.add_(functools.reduce(lambda x, f: f(mod, x, idx), self.operands, p),
+                            alpha=-group["lr"] * group["weight_decay"])
+                idx += 1
+
+
+class LpWeightDecay(WeightDecayBase):
+    def __init__(self, power: float):
+        self.power = power
+
+    def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
+        return p.abs().pow(self.power) * p.sign()
+
+
+class L1WeightDecay(LpWeightDecay):
+    def __init__(self):
+        super().__init__(0)
+
+
+class L2WeightDecay(LpWeightDecay):
+    def __init__(self):
+        super().__init__(1)
+
+
+def _param_iterator(mod: torch.optim.Optimizer):
+    yield from (p.detach().clone() for group in mod.param_groups for p in group["params"])
+
+
+class WeightDecayToValue(WeightDecayBase):
+    def __init__(self):
+        self.target_values: List[Tensor] = ...
+        self.global_step = 0
+
+    def _on_step_start(self, mod: torch.optim.Optimizer):
+        pass
+
+    def _on_global_start(self, mod: torch.optim.Optimizer):
+        pass
+
+    def _preprocess(self, target: Tensor):
+        return target
+
+    def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
+        if idx == 0:
+            if self.global_step == 0:
+                self._on_global_start(mod)
+            self._on_step_start(mod)
+            self.global_step += 1
+        return p - self._preprocess(self.target_values[idx])
+
+
+class WeightDecayToInit(WeightDecayToValue):
+    def _on_global_start(self, mod: torch.optim.Optimizer):
+        self.target_values = list(_param_iterator(mod))
+
+
+class WeightDecayToEMA(WeightDecayToInit):
+    def __init__(self, beta: float = 0.999):
+        super().__init__()
+        self.beta = beta
+
+    def _on_global_start(self, mod: torch.optim.Optimizer):
+        self.target_values = [torch.zeros_like(x) for x in _param_iterator(mod)]
+
+    def _on_step_start(self, mod: torch.optim.Optimizer):
+        self.global_step += 1
+        for v, p in zip(self.target_values, _param_iterator(mod)):
+            v.mul_(self.beta).add_(p, alpha=1 - self.beta)
+
+    def _preprocess(self, target: Tensor):
+        return target / (1 - self.beta ** self.global_step)
+
+
 def ema_(base: Tensor, update: Tensor, beta: float, step: Optional[int] = None):
     base.mul_(beta).add_(update, alpha=1 - beta)
     if step is None:
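
A quick usage sketch of the classes added above (my own example, not part of the commit; the model and the SGD carrier are placeholders). WeightDecayChain only reads an optimizer's param_groups, so any optimizer whose groups carry "lr" and "weight_decay" can host it; calling the chain updates parameters in place without touching gradients.

    import torch
    from truegrad.optim import WeightDecayChain, L2WeightDecay

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)

    chain = WeightDecayChain(L2WeightDecay())   # plain decoupled L2 decay
    before = model.weight.detach().clone()
    chain(opt)                                  # in-place: p <- p - lr * weight_decay * p
    print(torch.allclose(model.weight, before * (1 - 0.1 * 0.01)))  # True

WeightDecayToInit, WeightDecayToEMA and L1WeightDecay compose the same way; for example, WeightDecayChain(WeightDecayToEMA(beta=0.99), L1WeightDecay()) moves each weight a fixed lr * weight_decay step toward a bias-corrected EMA of itself.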
@@ -31,12 +120,18 @@ def decay_weight_(state: Dict[str, Any], param: torch.nn.Parameter, group: Dict[
     param.mul_(1 - group["weight_decay"] * group["lr"])


+def _default_decay(weight_decay_cls: Optional[WeightDecayChain]) -> WeightDecayChain:
+    if weight_decay_cls is None:
+        return WeightDecayChain(L2WeightDecay())
+    return weight_decay_cls
+
+
 class OptimizerOptimizer(torch.optim.Optimizer):
     def __init__(self, params, inner_optimizer: torch.optim.Optimizer, learning_rate_learning_rate: float = 1,
-                 weight_decay: float = 0, decay_to_init: bool = False):
-        self.learning_rate_learning_rate = learning_rate_learning_rate
-
+                 weight_decay: float = 0, weight_decay_cls: Optional[WeightDecayChain] = None):
         self.inner_optimizer = inner_optimizer
+        self.learning_rate_learning_rate = learning_rate_learning_rate
+        self.weight_decay_cls = _default_decay(weight_decay_cls)
         param_groups = self.inner_optimizer.param_groups
         self.inner_optimizer.param_groups = []
         for group in param_groups:
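
As _default_decay above shows, every optimizer below that accepts weight_decay_cls falls back to plain decoupled L2 decay when the argument is omitted or None, so existing call sites keep their old behaviour:

    chain = _default_decay(None)                               # WeightDecayChain(L2WeightDecay())
    chain = _default_decay(WeightDecayChain(L1WeightDecay()))  # a user-supplied chain passes through unchanged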
@@ -45,14 +140,16 @@ def __init__(self, params, inner_optimizer: torch.optim.Optimizer, learning_rate
             group["params"] = [param]
             self.inner_optimizer.param_groups.append(group)

-        super(OptimizerOptimizer, self).__init__(params, {"weight_decay": weight_decay, "decay_to_init": decay_to_init})
+        super(OptimizerOptimizer, self).__init__(params, {"weight_decay": weight_decay})

     @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
             loss = closure()

+        self.weight_decay_cls(self)
+
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
@@ -80,10 +177,11 @@ def step(self, closure=None):


 class Sign(torch.optim.Optimizer):
-    def __init__(self, params, base: torch.optim.Optimizer, lr: float = 1, weight_decay: float = 0,
-                 decay_to_init: bool = False, eps: float = 1e-12, graft_to_self: bool = True):
-        super().__init__(params, {"weight_decay": weight_decay, "decay_to_init": decay_to_init, "lr": lr, "eps": eps,
-                                  "graft_to_self": graft_to_self})
+    def __init__(self, params, base: torch.optim.Optimizer, lr: float = 1, weight_decay: float = 0, eps: float = 1e-12,
+                 graft_to_self: bool = True, weight_decay_cls: Optional[WeightDecayChain] = None):
+        self.weight_decay_cls = _default_decay(weight_decay_cls)
+
+        super().__init__(params, {"weight_decay": weight_decay, "lr": lr, "eps": eps, "graft_to_self": graft_to_self})
         self.base = base

     @torch.no_grad()
@@ -94,14 +192,8 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()

-        params_flat = []
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_flat.append(p)
-                decay_weight_(self.state[p], p, group)
-
-        params_flat = [torch.clone(p.detach()) for p in params_flat]
-
+        self.weight_decay_cls(self)
+        params_flat = list(_param_iterator(self))
         self.base.step()

         for group in self.param_groups:
@@ -150,10 +242,12 @@ class Graft(torch.optim.Optimizer):
     """

     def __init__(self, params, magnitude: torch.optim.Optimizer, direction: torch.optim.Optimizer,
-                 weight_decay: float = 0, decay_to_init: bool = False, eps: float = 1e-12, lr: float = 1):
-        super().__init__(params, {"weight_decay": weight_decay, "decay_to_init": decay_to_init, "lr": lr, "eps": eps})
+                 weight_decay: float = 0, eps: float = 1e-12, lr: float = 1,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
+        super().__init__(params, {"weight_decay": weight_decay, "lr": lr, "eps": eps})
         self.magnitude = magnitude
         self.direction = direction
+        self.weight_decay_cls = _default_decay(weight_decay_cls)

     @torch.no_grad()
     def step(self, closure=None):
@@ -163,13 +257,8 @@ def step(self, closure=None):
             with torch.enable_grad():
                 loss = closure()

-        params_flat = []
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_flat.append(p)
-                decay_weight_(self.state[p], p, group)
-
-        original_params = [torch.clone(p.detach()) for p in params_flat]
+        self.weight_decay_cls(self)
+        original_params = list(_param_iterator(self))

         self.magnitude.step()
         magnitudes_flat = []
@@ -194,21 +283,16 @@ class TrueGrad(torch.optim.Optimizer):
     base_statistics: List[str] = []
     shared_statistics: List[str] = []

-    def __init__(self, params, lr: float = 1e-3,
-                 betas: List[float] = (),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_baseline: bool = False,
-                 enforce_baseline: bool = False):
+    def __init__(self, params, lr: float = 1e-3, betas: List[float] = (), eps: float = 1e-12,
+                 weight_decay: float = 1e-2, graft: bool = True, default_to_baseline: bool = False,
+                 enforce_baseline: bool = False, weight_decay_cls: Optional[WeightDecayChain] = None):
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                        decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                        enforce_baseline=enforce_baseline)
+                        default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline)
         super(TrueGrad, self).__init__(params, defaults)
+        self.weight_decay_cls = _default_decay(weight_decay_cls)

-    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], **kwargs: Tensor
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], **kwargs: Tensor) -> Tuple[
+        Optional[Tensor], Optional[Tensor], float]:
         raise NotImplementedError

     @torch.no_grad()
@@ -245,12 +329,7 @@ def step(self, closure=None):
                 step_t = state['step']
                 step_t += 1

-                # Perform stepweight decay
-                decay = group['lr'] * group['weight_decay']
-                if group["decay_to_init"]:
-                    p.add_(state["init"] - p, alpha=decay)
-                else:
-                    p.mul_(1 - decay)
+                self.weight_decay_cls(self)

                 step = step_t.item()

@@ -275,28 +354,18 @@ class TGAdamW(TrueGrad):

     def __init__(self, params, lr: float = 1e-3,
                  betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_adam: bool = None,
-                 default_to_baseline: bool = None,
-                 enforce_baseline: bool = False):
-        if default_to_baseline is None:
-            default_to_baseline = default_to_adam
-        elif default_to_adam is not None:
-            raise ValueError("Can't set both default_to_baseline and default_to_adam, as both map to the same argument")
-        if default_to_adam is not None:
-            warnings.warn("default_to_adam is deprecated and will be replaced by default_to_baseline in April 2023")
+                 eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True,
+                 default_to_baseline: bool = None, enforce_baseline: bool = False,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
         if default_to_baseline is None:
             default_to_baseline = False
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                         enforce_baseline=enforce_baseline)
+                         default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
+                         weight_decay_cls=weight_decay_cls)

     def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg: Tensor,
-               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None) -> Tuple[
+        Optional[Tensor], Optional[Tensor], float]:
         if len(group["betas"]) == 2:
             (beta1, beta2), (_, beta3) = group["betas"], group["betas"]
         else:
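
The decay_to_init and deprecated default_to_adam keyword arguments are gone from TGAdamW's signature above; decaying toward the initial weights is now expressed as a decay chain, which produces the same update p += lr * weight_decay * (init - p) as the removed branch. A migration sketch (my own, assuming an existing model):

    # before this commit
    opt = TGAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, decay_to_init=True)
    # after this commit
    opt = TGAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2,
                  weight_decay_cls=WeightDecayChain(WeightDecayToInit()))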
@@ -317,21 +386,17 @@ class TGLaProp(TrueGrad):
     base_statistics: List[str] = ["exp_avg", "exp_avg_sq"]

     def __init__(self, params, lr: float = 1e-3,
-                 betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_baseline: bool = False,
-                 enforce_baseline: bool = False):
+                 betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99), eps: float = 1e-12,
+                 weight_decay: float = 1e-2, graft: bool = True, decay_to_init: bool = False,
+                 default_to_baseline: bool = False, enforce_baseline: bool = False,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                         enforce_baseline=enforce_baseline)
+                         default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
+                         weight_decay_cls=weight_decay_cls)

-    def _inner(self, step: int, p: Parameter, group: Dict[str, Any],
-               exp_avg: Optional[Tensor] = None, exp_avg_sq: Optional[Tensor] = None,
-               exp_avg_true: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg: Optional[Tensor] = None,
+               exp_avg_sq: Optional[Tensor] = None, exp_avg_true: Optional[Tensor] = None,
+               exp_avg_true_sq: Optional[Tensor] = None) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
         if len(group["betas"]) == 2:
             (beta1, beta2), (beta3, beta4) = group["betas"], group["betas"]
         else:
@@ -362,21 +427,16 @@ class TGRMSProp(TrueGrad):
     true_statistics: List[str] = ["exp_avg_true_sq"]
     base_statistics: List[str] = ["exp_avg_sq"]

-    def __init__(self, params, lr: float = 1e-3,
-                 betas: Union[float, Tuple[float], Tuple[float, float]] = (0.9,),
-                 eps: float = 1e-12,
-                 weight_decay: float = 1e-2,
-                 graft: bool = True,
-                 decay_to_init: bool = False,
-                 default_to_baseline: bool = False,
-                 enforce_baseline: bool = False):
+    def __init__(self, params, lr: float = 1e-3, betas: Union[float, Tuple[float], Tuple[float, float]] = (0.9,),
+                 eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True,
+                 default_to_baseline: bool = False, enforce_baseline: bool = False,
+                 weight_decay_cls: Optional[WeightDecayChain] = None):
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                         decay_to_init=decay_to_init, default_to_baseline=default_to_baseline,
-                         enforce_baseline=enforce_baseline)
+                         default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
+                         weight_decay_cls=weight_decay_cls)

-    def _inner(self, step: int, p: Parameter, group: Dict[str, Any],
-               exp_avg_sq: Optional[Tensor] = None, exp_avg_true_sq: Optional[Tensor] = None
-               ) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
+    def _inner(self, step: int, p: Parameter, group: Dict[str, Any], exp_avg_sq: Optional[Tensor] = None,
+               exp_avg_true_sq: Optional[Tensor] = None) -> Tuple[Optional[Tensor], Optional[Tensor], float]:
         if isinstance(group["betas"], float):
             beta1 = beta2 = group["betas"]
         elif len(group["betas"]) == 1:
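
All of the optimizers above now accept the same weight_decay_cls argument. A final sketch (my own, not from the commit; model is a placeholder, and default_to_baseline=True is assumed to let the optimizer run on ordinary gradient statistics, since the rest of the step logic is outside this diff): decaying toward an exponential moving average of the weights instead of toward zero.

    import torch
    from truegrad.optim import TGAdamW, WeightDecayChain, WeightDecayToEMA

    model = torch.nn.Linear(16, 16)
    opt = TGAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, default_to_baseline=True,
                  weight_decay_cls=WeightDecayChain(WeightDecayToEMA(beta=0.999)))

    opt.zero_grad()
    model(torch.randn(8, 16)).square().mean().backward()
    opt.step()  # weight decay now pulls each parameter toward its bias-corrected EMA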
