@@ -505,13 +505,7 @@ def __init__(
         self._optimizer_step_post_hooks: Dict[int, Callable] = {}
 
         self.state: STATE = defaultdict(dict)
-
-        for group in self.param_groups:
-            for p in group['params']:
-                state = self.state[p]
-                state['z'] = torch.clone(p)
-
-        self.defaults = self.optimizer.defaults
+        self.defaults: DEFAULTS = self.optimizer.defaults
 
     def __str__(self) -> str:
         return 'ScheduleFree'
@@ -594,6 +588,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
 
+                if 'z' not in state:
+                    state['z'] = p.clone()
+
                 z = state['z']
 
                 self.apply_weight_decay(
@@ -633,7 +630,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             weight: float = (group['step'] ** group['lr']) * (lr_max ** self.weight_lr_power)  # fmt: skip
             weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight
 
-            ckeckpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
+            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
 
             for p in group['params']:
                 if p.grad is None:
@@ -645,7 +642,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 self.swap(z, p)
 
-                p.lerp_(end=z, weight=ckeckpoint)
+                p.lerp_(end=z, weight=checkpoint)
 
                 p.lerp_(end=state['z'], weight=1.0 - self.momentum)
 
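The change above moves the creation of the per-parameter 'z' buffer out of __init__ and into step(): the buffer is cloned from the parameter the first time that parameter is actually updated, rather than eagerly for every parameter at construction time. A minimal sketch of this lazy-state pattern follows; the wrapper class and its members here are illustrative only, not the library's implementation.

from collections import defaultdict

import torch


class LazyStateWrapper:
    def __init__(self, params):
        # No per-parameter buffers are allocated here.
        self.param_groups = [{'params': list(params)}]
        self.state = defaultdict(dict)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if 'z' not in state:
                    # Lazy initialization, mirroring the hunk at line 591 above:
                    # the buffer is created on first use for this parameter.
                    state['z'] = p.clone()

                # ... parameter and state['z'] updates would go here ...

One practical effect is that parameters which never receive a gradient never get a 'z' clone, and the wrapper no longer touches self.param_groups during __init__.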