Skip to content

Commit a269e41

Browse files
committed
fix: key error 'z'
1 parent 323e6c6 commit a269e41

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

pytorch_optimizer/optimizer/schedulefree.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -505,13 +505,7 @@ def __init__(
505505
self._optimizer_step_post_hooks: Dict[int, Callable] = {}
506506

507507
self.state: STATE = defaultdict(dict)
508-
509-
for group in self.param_groups:
510-
for p in group['params']:
511-
state = self.state[p]
512-
state['z'] = torch.clone(p)
513-
514-
self.defaults = self.optimizer.defaults
508+
self.defaults: DEFAULTS = self.optimizer.defaults
515509

516510
def __str__(self) -> str:
517511
return 'ScheduleFree'
@@ -594,6 +588,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
594588

595589
state = self.state[p]
596590

591+
if 'z' not in state:
592+
state['z'] = p.clone()
593+
597594
z = state['z']
598595

599596
self.apply_weight_decay(
@@ -633,7 +630,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
633630
weight: float = (group['step'] ** group['lr']) * (lr_max ** self.weight_lr_power) # fmt: skip
634631
weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight
635632

636-
ckeckpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
633+
checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
637634

638635
for p in group['params']:
639636
if p.grad is None:
@@ -645,7 +642,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
645642

646643
self.swap(z, p)
647644

648-
p.lerp_(end=z, weight=ckeckpoint)
645+
p.lerp_(end=z, weight=checkpoint)
649646

650647
p.lerp_(end=state['z'], weight=1.0 - self.momentum)
651648

0 commit comments

Comments
 (0)