
Commit 9a43be0

fix(optim): use weight decay cls in optimizeroptimizer
1 parent 7a12055 commit 9a43be0

File tree

2 files changed: 4 additions & 17 deletions


setup.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='4.0.1',
+    version='4.0.2',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),

truegrad/optim.py

Lines changed: 3 additions & 16 deletions

@@ -115,16 +115,6 @@ def div_ema(base: Tensor, eps: float, base_sq: Tensor, update_sq: Tensor, beta_s
     return base / stable_sqrt(ema_(base_sq, update_sq, beta_sq, step), eps)
 
 
-def decay_weight_(state: Dict[str, Any], param: torch.nn.Parameter, group: Dict[str, Any]):
-    if group["decay_to_init"]:
-        if "param_at_init" not in state:
-            state["param_at_init"] = torch.clone(param.detach())
-        else:
-            param.add_(state["param_at_init"] - param, alpha=group["weight_decay"] * group["lr"])
-    else:
-        param.mul_(1 - group["weight_decay"] * group["lr"])
-
-
 def _default_decay(weight_decay_cls: Optional[WeightDecayChain]) -> WeightDecayChain:
     if weight_decay_cls is None:
         return WeightDecayChain(L2WeightDecay())
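For readers following the refactor: the deleted decay_weight_ helper hard-coded both decay modes in one function, while the new design expresses each as a member of a WeightDecayChain. A minimal sketch of how the two removed behaviors could look as chain members, assuming the callable-per-parameter interface implied by _default_decay above; the real L2WeightDecay lives elsewhere in truegrad/optim.py, and the decay-to-init member name below is hypothetical:

    import torch

    class L2WeightDecay:
        # Multiplicative decay toward zero; mirrors the removed `else` branch:
        #   param *= 1 - weight_decay * lr
        def __call__(self, state, param, group):
            param.mul_(1 - group["weight_decay"] * group["lr"])

    class DecayToInit:  # hypothetical name, not confirmed by this diff
        # Mirrors the removed `decay_to_init` branch: snapshot the parameter
        # on first use, then pull it back toward that snapshot on later steps.
        def __call__(self, state, param, group):
            if "param_at_init" not in state:
                state["param_at_init"] = torch.clone(param.detach())
            else:
                param.add_(state["param_at_init"] - param,
                           alpha=group["weight_decay"] * group["lr"])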
@@ -153,16 +143,15 @@ def step(self, closure=None):
         if closure is not None:
             loss = closure()
 
-        self.weight_decay_cls(self)
-
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
                 if "lr" in state:
                     group["lr"] = state["lr"]
-                decay_weight_(state, p, group)
                 state["param"] = torch.clone(p.detach())
 
+        self.weight_decay_cls(self)
+
         self.inner_optimizer.step()
 
         for group in self.inner_optimizer.param_groups:
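The fix in this hunk is ordering: the weight decay chain now runs once, after the per-parameter bookkeeping (the "lr" override and the state["param"] snapshot) and immediately before the inner optimizer steps, rather than before the loop and once more per parameter via decay_weight_. A minimal sketch of what calling self.weight_decay_cls(self) might do, assuming WeightDecayChain simply walks every (state, param, group) triple; the actual implementation in truegrad/optim.py may differ:

    class WeightDecayChain:
        # Sketch only: apply each member's decay to every parameter.
        def __init__(self, *members):
            self.members = members

        def __call__(self, optimizer):
            for group in optimizer.param_groups:
                if group.get("weight_decay", 0) == 0:
                    continue  # nothing to decay in this group
                for p in group["params"]:
                    for member in self.members:
                        member(optimizer.state[p], p, group)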
@@ -330,8 +319,6 @@ def step(self, closure=None):
                 if do_base or group["graft"]:
                     for s in self.base_statistics:
                         state[s] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                if group["decay_to_init"]:
-                    state["init"] = torch.clone(p.detach())
 
                 step_t = state['step']
                 step_t += 1
@@ -393,7 +380,7 @@ class TGLaProp(TrueGrad):
 
     def __init__(self, params, lr: float = 1e-3,
                  betas: Union[Tuple[float, float], Tuple[float, float, float, float]] = (0.9, 0.99), eps: float = 1e-12,
-                 weight_decay: float = 1e-2, graft: bool = True, decay_to_init: bool = False,
+                 weight_decay: float = 1e-2, graft: bool = True,
                  default_to_baseline: bool = False, enforce_baseline: bool = False,
                  weight_decay_cls: Optional[WeightDecayChain] = None):
         super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
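With decay_to_init gone from the TGLaProp signature, callers wanting that behavior would opt in through weight_decay_cls instead. A hedged usage sketch, assuming the import path matches this file's location and reusing only names visible in the diff; a decay-to-init chain member would be passed the same way if the library provides one:

    import torch
    from truegrad.optim import TGLaProp, WeightDecayChain, L2WeightDecay

    model = torch.nn.Linear(4, 4)

    # Default: _default_decay falls back to WeightDecayChain(L2WeightDecay()).
    opt = TGLaProp(model.parameters(), lr=1e-3, weight_decay=1e-2)

    # Explicit chain; adding a decay-to-init member (name assumed, check the
    # module) would replace the removed decay_to_init=True flag.
    opt = TGLaProp(model.parameters(), lr=1e-3, weight_decay=1e-2,
                   weight_decay_cls=WeightDecayChain(L2WeightDecay()))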

0 commit comments
