@@ -1,14 +1,14 @@
-from collections import defaultdict, OrderedDict
-from typing import Dict, Callable
+from collections import defaultdict
+from typing import Callable, Dict

 import torch
 from torch.optim import Optimizer

 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE, DEFAULTS
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE


-class Lookahead(BaseOptimizer, Optimizer):
+class Lookahead(Optimizer, BaseOptimizer):
     r"""k steps forward, 1 step back.

     :param optimizer: OPTIMIZER. base optimizer.
@@ -28,6 +28,9 @@ def __init__(
         self.validate_range(alpha, 'alpha', 0.0, 1.0)
         self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])

+        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
+        self._optimizer_step_post_hooks: Dict[int, Callable] = {}
+
         self.alpha = alpha
         self.k = k
         self.pullback_momentum = pullback_momentum
@@ -47,15 +50,13 @@ def __init__(
                 state['slow_params'].copy_(p)
                 if self.pullback_momentum == 'pullback':
                     state['slow_momentum'] = torch.zeros_like(p)
-
-        # Instead of calling super().__init__, we set the attributes ourselves
-        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
-        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
+
         self.defaults: DEFAULTS = {
+            'lookahead_alpha': alpha,
+            'lookahead_k': k,
+            'lookahead_pullback_momentum': pullback_momentum,
             **optimizer.defaults,
-            **dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_pullback_momentum=pullback_momentum),
         }
-

     def __getstate__(self):
         return {
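For reference, a minimal usage sketch of the class touched by this diff: Lookahead wraps a base optimizer, and the constructor's alpha/k/pullback_momentum arguments are the values merged into self.defaults above. The AdamW base optimizer, the toy linear model, and the hyperparameter values below are illustrative assumptions, not part of this change.

import torch
from torch.optim import AdamW

from pytorch_optimizer import Lookahead

# Toy model for illustration; any nn.Module works.
model = torch.nn.Linear(10, 1)

# Wrap a base optimizer; alpha, k and pullback_momentum end up in
# self.defaults as 'lookahead_alpha', 'lookahead_k' and
# 'lookahead_pullback_momentum', alongside the base optimizer's defaults.
base_optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5, pullback_momentum='none')

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()       # steps the base optimizer; every k-th call syncs the slow weights
optimizer.zero_grad()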