
Commit 3ef7c75

Merge pull request #367 from kozistr/fix/schedulefree-wrapper

[Fix] ScheduleFreeWrapper

2 parents: aca2ef2 + 9efb597

3 files changed: +17 −11 lines

docs/changelogs/v3.5.0.md
Lines changed: 1 addition & 0 deletions

@@ -25,3 +25,4 @@
 
 * bias_correction2 in ScheduleFreeRAdam optimizer. (#354)
 * potential bug in SPAM optimizer. (#365)
+* initialize the `z` state within the `step()` of the ScheduleFreeWrapper. (#363, #366)

pytorch_optimizer/optimizer/schedulefree.py
Lines changed: 6 additions & 9 deletions
@@ -505,13 +505,7 @@ def __init__(
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
 
         self.state: STATE = defaultdict(dict)
-
-        for group in self.param_groups:
-            for p in group['params']:
-                state = self.state[p]
-                state['z'] = torch.clone(p)
-
-        self.defaults = self.optimizer.defaults
+        self.defaults: DEFAULTS = self.optimizer.defaults
 
     def __str__(self) -> str:
         return 'ScheduleFree'
@@ -594,6 +588,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
 
+                if 'z' not in state:
+                    state['z'] = p.clone()
+
                 z = state['z']
 
                 self.apply_weight_decay(
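The hunk above is the heart of the fix: `z` is no longer cloned eagerly for every parameter in `__init__`, but created on first use inside `step()`. A minimal sketch of this lazy-initialization pattern, using a hypothetical stripped-down wrapper (`LazyWrapper` is illustrative, not the library's full `ScheduleFreeWrapper`):

```python
# Hypothetical, stripped-down illustration of lazy state initialization;
# not the library's actual ScheduleFreeWrapper.
from collections import defaultdict
from typing import Dict

import torch


class LazyWrapper:
    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self.optimizer = optimizer
        self.state: Dict[torch.Tensor, Dict[str, torch.Tensor]] = defaultdict(dict)

    @torch.no_grad()
    def step(self) -> None:
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # create `z` on first use: a state restored through
                # load_state_dict() before the first step is preserved, and
                # parameters that never receive gradients are never cloned
                if 'z' not in state:
                    state['z'] = p.clone()

        self.optimizer.step()
```

The same guard is what lets the reworked test below load a backed-up state into a freshly constructed wrapper.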
@@ -633,7 +630,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             weight: float = (group['step'] ** group['lr']) * (lr_max ** self.weight_lr_power)  # fmt: skip
             weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight
 
-            ckeckpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
+            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
 
             for p in group['params']:
                 if p.grad is None:
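Besides renaming the misspelled `ckeckpoint`, it helps to see what this factor does: each step contributes `weight = step ** lr * lr_max ** weight_lr_power` to a running `weight_sum`, and `checkpoint = weight / weight_sum` is the interpolation fraction used below. A quick numeric sketch with illustrative values (not the library's defaults):

```python
# Illustrative values only: with a small lr, step ** lr stays near 1.0,
# so the per-step weights are nearly equal and checkpoint decays like 1/step.
lr, lr_max, weight_lr_power = 1e-3, 1e-3, 2.0

weight_sum = 0.0
for step in range(1, 4):
    weight = (step ** lr) * (lr_max ** weight_lr_power)
    weight_sum += weight
    checkpoint = weight / weight_sum if weight_sum != 0.0 else 0.0
    print(step, round(checkpoint, 4))  # 1 1.0, 2 0.5002, 3 0.3335
```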
@@ -645,7 +642,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 self.swap(z, p)
 
-                p.lerp_(end=z, weight=ckeckpoint)
+                p.lerp_(end=z, weight=checkpoint)
 
                 p.lerp_(end=state['z'], weight=1.0 - self.momentum)
 
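For reference, `Tensor.lerp_(end, weight)` computes `self + weight * (end - self)` in place, so the renamed `checkpoint` is the fraction of `z` blended into `p` (and `1.0 - self.momentum` plays the same role in the second call). A tiny sketch:

```python
import torch

p = torch.zeros(3)
z = torch.ones(3)

# p <- p + 0.25 * (z - p): an in-place convex combination
p.lerp_(end=z, weight=0.25)
print(p)  # tensor([0.2500, 0.2500, 0.2500])
```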
tests/test_optimizers.py
Lines changed: 10 additions & 2 deletions
@@ -1026,9 +1026,17 @@ def test_schedulefree_wrapper():
     _ = optimizer.__getstate__()
     _ = optimizer.param_groups
 
-    optimizer.load_state_dict(optimizer.state_dict())
+    optimizer.step()
+
+    backup_state = optimizer.state_dict()
+
+    optimizer = ScheduleFreeWrapper(load_optimizer('adamw')(model.parameters(), lr=1e-3, weight_decay=1e-3))
+    optimizer.reset()
+    optimizer.zero_grad()
+    optimizer.train()
+
+    optimizer.load_state_dict(backup_state)
 
-    optimizer.optimizer.step()
     optimizer.step()
 
     optimizer.eval()
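Because `z` now only exists after the first `step()`, the reworked test steps once before saving, then restores the backup into a freshly built wrapper. A hedged usage sketch of the same round-trip (assuming the package's public `ScheduleFreeWrapper` export; the model and hyperparameters are placeholders):

```python
import torch
from pytorch_optimizer import ScheduleFreeWrapper

model = torch.nn.Linear(4, 1)
optimizer = ScheduleFreeWrapper(torch.optim.AdamW(model.parameters(), lr=1e-3))
optimizer.train()

model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # `z` is created here, so the state dict below is non-empty

backup = optimizer.state_dict()

# a fresh wrapper starts with empty state; loading the backup restores `z`,
# and the `if 'z' not in state` guard keeps step() from re-cloning it
optimizer = ScheduleFreeWrapper(torch.optim.AdamW(model.parameters(), lr=1e-3))
optimizer.train()
optimizer.load_state_dict(backup)
optimizer.step()
```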
