1- from collections import defaultdict
2- from typing import Dict
1+ from collections import defaultdict , OrderedDict
2+ from typing import Dict , Callable
33
44import torch
5+ from torch .optim import Optimizer
56
67from pytorch_optimizer .base .optimizer import BaseOptimizer
7- from pytorch_optimizer .base .types import CLOSURE , LOSS , OPTIMIZER , STATE
8+ from pytorch_optimizer .base .types import CLOSURE , LOSS , OPTIMIZER , STATE , DEFAULTS
89
910
10- class Lookahead (BaseOptimizer ):
11+ class Lookahead (BaseOptimizer , Optimizer ):
1112 r"""k steps forward, 1 step back.
1213
1314 :param optimizer: OPTIMIZER. base optimizer.
@@ -33,6 +34,7 @@ def __init__(
3334
3435 self .optimizer = optimizer
3536 self .param_groups = self .optimizer .param_groups
37+
3638 self .state : STATE = defaultdict (dict )
3739
3840 for group in self .param_groups :
@@ -45,6 +47,15 @@ def __init__(
4547 state ['slow_params' ].copy_ (p )
4648 if self .pullback_momentum == 'pullback' :
4749 state ['slow_momentum' ] = torch .zeros_like (p )
50+
51+ # Instead of calling super().__init__, we set the attributes ourselves
52+ self ._optimizer_step_pre_hooks : Dict [int , Callable ] = OrderedDict ()
53+ self ._optimizer_step_post_hooks : Dict [int , Callable ] = OrderedDict ()
54+ self .defaults : DEFAULTS = {
55+ ** optimizer .defaults ,
56+ ** dict (lookahead_alpha = alpha , lookahead_k = k , lookahead_pullback_momentum = pullback_momentum ),
57+ }
58+
4859
4960 def __getstate__ (self ):
5061 return {
0 commit comments