11from collections import defaultdict
2- from typing import Callable , Dict , List , Optional
2+ from typing import Dict
33
44import torch
55from torch .optim import Optimizer
66
7+ from pytorch_optimizer .types import (
8+ CLOSURE ,
9+ LOSS ,
10+ PARAM_GROUP ,
11+ PARAM_GROUPS ,
12+ STATE ,
13+ )
14+
715
816class Lookahead (Optimizer ):
917 def __init__ (self , optimizer : Optimizer , k : int = 5 , alpha : float = 0.5 ):
1018 self .optimizer = optimizer
1119 self .k = k
1220 self .alpha = alpha
1321
14- self .param_groups : List [ Dict ] = self .optimizer .param_groups
15- self .fast_state : Dict = self .optimizer .state
16- self .state = defaultdict (dict )
22+ self .param_groups : PARAM_GROUPS = self .optimizer .param_groups
23+ self .fast_state : STATE = self .optimizer .state
24+ self .state : STATE = defaultdict (dict )
1725
1826 for group in self .param_groups :
1927 group ['counter' ] = 0
@@ -32,8 +40,8 @@ def update_lookahead(self):
3240 for group in self .param_groups :
3341 self .update (group )
3442
35- def step (self , closure : Optional [ Callable ] = None ) -> float :
36- loss : float = self .optimizer .step (closure )
43+ def step (self , closure : CLOSURE = None ) -> LOSS :
44+ loss : LOSS = self .optimizer .step (closure )
3745 for group in self .param_groups :
3846 if group ['counter' ] == 0 :
3947 self .update (group )
@@ -42,12 +50,12 @@ def step(self, closure: Optional[Callable] = None) -> float:
4250 group ['counter' ] = 0
4351 return loss
4452
45- def state_dict (self ) -> Dict [ str , torch . Tensor ] :
46- fast_state_dict = self .optimizer .state_dict ()
53+ def state_dict (self ) -> STATE :
54+ fast_state_dict : STATE = self .optimizer .state_dict ()
4755 fast_state = fast_state_dict ['state' ]
4856 param_groups = fast_state_dict ['param_groups' ]
4957
50- slow_state : Dict [ int , torch . Tensor ] = {
58+ slow_state : STATE = {
5159 (id (k ) if isinstance (k , torch .Tensor ) else k ): v
5260 for k , v in self .state .items ()
5361 }
@@ -58,12 +66,12 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
5866 'param_groups' : param_groups ,
5967 }
6068
61- def load_state_dict (self , state_dict : Dict [ str , torch . Tensor ] ):
62- slow_state_dict : Dict [ str , torch . Tensor ] = {
69+ def load_state_dict (self , state_dict : STATE ):
70+ slow_state_dict : STATE = {
6371 'state' : state_dict ['slow_state' ],
6472 'param_groups' : state_dict ['param_groups' ],
6573 }
66- fast_state_dict : Dict [ str , torch . Tensor ] = {
74+ fast_state_dict : STATE = {
6775 'state' : state_dict ['fast_state' ],
6876 'param_groups' : state_dict ['param_groups' ],
6977 }
@@ -72,6 +80,6 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
7280 self .optimizer .load_state_dict (fast_state_dict )
7381 self .fast_state = self .optimizer .state
7482
75- def add_param_group (self , param_group : Dict ):
83+ def add_param_group (self , param_group : PARAM_GROUP ):
7684 param_group ['counter' ] = 0
7785 self .optimizer .add_param_group (param_group )
0 commit comments