@@ -2,13 +2,12 @@
 from typing import Dict

 import torch
-from torch.optim import Optimizer

 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE
+from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE


-class Lookahead(Optimizer, BaseOptimizer):
+class Lookahead(BaseOptimizer):
     r"""k steps forward, 1 step back.

     :param optimizer: OPTIMIZER. base optimizer.
@@ -17,7 +16,7 @@ class Lookahead(Optimizer, BaseOptimizer):
     :param pullback_momentum: str. change to inner optimizer momentum on interpolation update.
     """

-    def __init__(  # pylint: disable=super-init-not-called
+    def __init__(
         self,
         optimizer: OPTIMIZER,
         k: int = 5,
@@ -32,62 +31,90 @@ def __init__(  # pylint: disable=super-init-not-called
         self.validate_parameters()

         self.param_groups = self.optimizer.param_groups
-        self.fast_state: STATE = self.optimizer.state
         self.state: STATE = defaultdict(dict)
-        self.reset()

-        self.defaults: DEFAULTS = optimizer.defaults
-        self.defaults.update(
-            {
-                'k': k,
-                'alpha': alpha,
-                'pullback_momentum': pullback_momentum,
-            }
-        )
+        for group in self.param_groups:
+            if 'counter' not in group:
+                group['counter'] = 0
+
+            for p in group['params']:
+                state = self.state[p]
+                state['slow_params'] = torch.empty_like(p)
+                state['slow_params'].copy_(p)
+                if self.pullback_momentum == 'pullback':
+                    state['slow_momentum'] = torch.zeros_like(p)

     def validate_parameters(self):
         self.validate_lookahead_k(self.k)
         self.validate_alpha(self.alpha)
         self.validate_pullback_momentum(self.pullback_momentum)

+    def __getstate__(self):
+        return {
+            'state': self.state,
+            'optimizer': self.optimizer,
+            'alpha': self.alpha,
+            'k': self.k,
+            'pullback_momentum': self.pullback_momentum,
+        }
+
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
             group['counter'] = 0

+    def backup_and_load_cache(self):
+        r"""Back up the current parameters and load the cached (slow) parameters."""
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['backup_params'] = torch.empty_like(p)
+                state['backup_params'].copy_(p)
+                p.data.copy_(state['slow_params'])
+
+    def clear_and_load_backup(self):
+        r"""Restore the backed-up parameters and clear the backup."""
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                p.data.copy_(state['backup_params'])
+                del state['backup_params']
+
+    def state_dict(self) -> STATE:
+        return self.optimizer.state_dict()
+
+    def load_state_dict(self, state: STATE):
+        r"""Load state."""
+        self.optimizer.load_state_dict(state)
+
+    @torch.no_grad()
+    def zero_grad(self):
+        self.optimizer.zero_grad(set_to_none=True)
+
     @torch.no_grad()
     def update(self, group: Dict):
-        for fast in group['params']:
-            if fast.grad is None:
+        for p in group['params']:
+            if p.grad is None:
                 continue

-            param_state = self.state[fast]
-            if 'slow_param' not in param_state:
-                param_state['slow_param'] = torch.empty_like(fast)
-                param_state['slow_param'].copy_(fast)
-                if self.pullback_momentum == 'pullback':
-                    param_state['slow_mom'] = torch.zeros_like(fast)
+            state = self.state[p]

-            slow = param_state['slow_param']
-            slow.add_(fast - slow, alpha=self.alpha)
+            slow = state['slow_params']

-            fast.copy_(slow)
+            p.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
+            slow.copy_(p)

-            if 'momentum_buffer' not in self.optimizer.state[fast]:
-                self.optimizer.state[fast]['momentum_buffer'] = torch.zeros_like(fast)
+            if 'momentum_buffer' not in self.optimizer.state[p]:
+                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

             if self.pullback_momentum == 'pullback':
-                internal_momentum = self.optimizer.state[fast]['momentum_buffer']
-                self.optimizer.state[fast]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
-                    param_state['slow_mom'], alpha=1.0 - self.alpha
+                internal_momentum = self.optimizer.state[p]['momentum_buffer']
+                self.optimizer.state[p]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
+                    state['slow_momentum'], alpha=1.0 - self.alpha
                 )
-                param_state['slow_mom'] = self.optimizer.state[fast]['momentum_buffer']
+                state['slow_momentum'] = self.optimizer.state[p]['momentum_buffer']
             elif self.pullback_momentum == 'reset':
-                self.optimizer.state[fast]['momentum_buffer'] = torch.zeros_like(fast)
-
-    def update_lookahead(self):
-        for group in self.param_groups:
-            self.update(group)
+                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = self.optimizer.step(closure)
@@ -97,25 +124,3 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             group['counter'] = 0
             self.update(group)
         return loss
-
-    def state_dict(self) -> STATE:
-        fast_state: STATE = self.optimizer.state_dict()
-        slow_state: STATE = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()}
-
-        return {
-            'fast_state': fast_state['state'],
-            'slow_state': slow_state,
-            'param_groups': fast_state['param_groups'],
-        }
-
-    def load_state_dict(self, state: STATE):
-        slow_state: STATE = {'state': state['slow_state'], 'param_groups': state['param_groups']}
-        fast_state: STATE = {'state': state['fast_state'], 'param_groups': state['param_groups']}
-        super().load_state_dict(slow_state)
-
-        self.optimizer.load_state_dict(fast_state)
-        self.fast_state = self.optimizer.state
-
-    def add_param_group(self, param_group):
-        param_group['counter'] = 0
-        self.optimizer.add_param_group(param_group)
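
Note on the rewritten interpolation in update: the old code computed slow <- slow + alpha * (fast - slow) and copied the result into the fast weights, while the new code computes p <- alpha * p + (1 - alpha) * slow in place and copies it back into slow. Both end with the same value, alpha * fast + (1 - alpha) * slow, in both tensors. A minimal standalone check (not part of the diff; tensor names are illustrative):

import torch

alpha = 0.5
fast = torch.randn(3)
slow = torch.randn(3)

# old formulation: slow <- slow + alpha * (fast - slow), then fast <- slow
old = slow + alpha * (fast - slow)

# new formulation: fast <- alpha * fast + (1 - alpha) * slow, then slow <- fast
new = fast.mul(alpha).add(slow, alpha=1.0 - alpha)

assert torch.allclose(old, new)  # both equal alpha * fast + (1 - alpha) * slow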
0 commit comments
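
A minimal usage sketch of the wrapper after this change. This is an assumption-laden example, not part of the commit: the top-level import path for Lookahead and the pullback_momentum default are taken from the package README, and the model and data are placeholders.

import torch
from torch import nn
from pytorch_optimizer import Lookahead  # import path assumed

model = nn.Linear(10, 1)
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-1, momentum=0.9)
# every k inner steps, interpolate fast weights toward the slow weights by alpha
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5, pullback_momentum='none')

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()

# the newly added helpers allow evaluating with the slow (averaged) weights,
# then restoring the fast weights before training resumes
optimizer.backup_and_load_cache()
eval_loss = model(torch.randn(8, 10)).pow(2).mean()
optimizer.clear_and_load_backup()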