
Commit 7a12055

fix(optim): add missing variable in Graft

1 parent d9b50a8

File tree: setup.py, truegrad/optim.py

2 files changed: +11 -4 lines


setup.py
1 addition, 1 deletion

@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='4.0.0',
+    version='4.0.1',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),

truegrad/optim.py
10 additions, 3 deletions

@@ -1,5 +1,4 @@
 import functools
-import warnings
 from typing import Tuple, Union, List, Dict, Any, Optional
 
 import torch
@@ -30,6 +29,7 @@ def __call__(self, mod: torch.optim.Optimizer):
 
 class LpWeightDecay(WeightDecayBase):
     def __init__(self, power: float):
+        super().__init__()
         self.power = power
 
     def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
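Context for the super().__init__() additions in this commit: a subclass that defines its own __init__ without calling the base initializer never runs the base class's setup code. A hypothetical sketch of the failure mode (WeightDecayBase's real body is not part of this diff, so Base below is invented purely for illustration):

class Base:
    def __init__(self, strength: float = 0.0):
        self.strength = strength  # assumed base-class state, illustration only

class Decay(Base):
    def __init__(self, power: float):
        super().__init__()  # without this line, self.strength never exists
        self.power = power

d = Decay(2.0)
print(d.strength, d.power)  # -> 0.0 2.0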
@@ -46,12 +46,17 @@ def __init__(self):
         super().__init__(1)
 
 
-def _param_iterator(mod: torch.optim.Optimizer):
-    yield from (p.detach().clone() for group in mod.param_groups for p in group["params"])
+def _detach(x: Tensor) -> Tensor:
+    return x.detach().clone()
+
+
+def _param_iterator(mod: torch.optim.Optimizer, fn=_detach):
+    yield from (fn(p) for group in mod.param_groups for p in group["params"])
 
 
 class WeightDecayToValue(WeightDecayBase):
     def __init__(self):
+        super().__init__()
         self.target_values: List[Tensor] = ...
         self.global_step = 0
 
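The refactored helper above is self-contained enough to try directly. A minimal sketch (a toy SGD optimizer stands in for TrueGrad's) showing the two call styles the new fn parameter enables, detached snapshots by default versus live parameter tensors:

import torch
from torch import Tensor

def _detach(x: Tensor) -> Tensor:
    return x.detach().clone()

def _param_iterator(mod: torch.optim.Optimizer, fn=_detach):
    yield from (fn(p) for group in mod.param_groups for p in group["params"])

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)
snapshots = list(_param_iterator(opt))          # detached clones, safe to store
live = list(_param_iterator(opt, lambda x: x))  # the parameter tensors themselves
assert snapshots[0] is not live[0]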
@@ -261,6 +266,8 @@ def step(self, closure=None):
         original_params = list(_param_iterator(self))
 
         self.magnitude.step()
+        params_flat = list(_param_iterator(self, lambda x: x))
+
         magnitudes_flat = []
         for o, p in zip(original_params, params_flat):
             magnitudes_flat.append(torch.norm(o.double() - p.double()))
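This hunk is the fix named in the commit message: Graft.step() zips original_params against params_flat, but params_flat was never assigned, so the loop would fail with a NameError. A minimal sketch of the repaired flow under toy assumptions (plain SGD stands in for the wrapped magnitude optimizer, and the list comprehensions mirror the two _param_iterator calls):

import torch

param = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.SGD([param], lr=0.1)

# snapshot before the inner step, as _param_iterator(self) does by default
original_params = [p.detach().clone() for g in opt.param_groups for p in g["params"]]

param.grad = torch.full_like(param, 0.5)
opt.step()

# live view after the step, mirroring _param_iterator(self, lambda x: x)
params_flat = [p for g in opt.param_groups for p in g["params"]]

magnitudes_flat = [torch.norm(o.double() - p.double())
                   for o, p in zip(original_params, params_flat)]
print(magnitudes_flat)  # per-parameter update magnitude; 0.1 for this toy step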
