Skip to content

Commit c53befe

Browse files
committed
refactor: Lookahead
1 parent ba85681 commit c53befe

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 10 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,8 @@ def update(self, group: Dict):
6969
param_state['slow_mom'] = torch.zeros_like(fast)
7070

7171
slow = param_state['slow_param']
72-
slow += (fast - slow) * self.alpha
72+
slow.add_(fast - slow, alpha=self.alpha)
73+
7374
fast.copy_(slow)
7475

7576
if 'momentum_buffer' not in self.optimizer.state[fast]:
@@ -98,30 +99,21 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9899
return loss
99100

100101
def state_dict(self) -> STATE:
101-
fast_state_dict: STATE = self.optimizer.state_dict()
102-
fast_state = fast_state_dict['state']
103-
param_groups = fast_state_dict['param_groups']
104-
102+
fast_state: STATE = self.optimizer.state_dict()
105103
slow_state: STATE = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()}
106104

107105
return {
108-
'fast_state': fast_state,
106+
'fast_state': fast_state['state'],
109107
'slow_state': slow_state,
110-
'param_groups': param_groups,
108+
'param_groups': fast_state['param_groups'],
111109
}
112110

113-
def load_state_dict(self, state_dict: STATE):
114-
slow_state_dict: STATE = {
115-
'state': state_dict['slow_state'],
116-
'param_groups': state_dict['param_groups'],
117-
}
118-
fast_state_dict: STATE = {
119-
'state': state_dict['fast_state'],
120-
'param_groups': state_dict['param_groups'],
121-
}
122-
super().load_state_dict(slow_state_dict)
111+
def load_state_dict(self, state: STATE):
112+
slow_state: STATE = {'state': state['slow_state'], 'param_groups': state['param_groups']}
113+
fast_state: STATE = {'state': state['fast_state'], 'param_groups': state['param_groups']}
114+
super().load_state_dict(slow_state)
123115

124-
self.optimizer.load_state_dict(fast_state_dict)
116+
self.optimizer.load_state_dict(fast_state)
125117
self.fast_state = self.optimizer.state
126118

127119
def add_param_group(self, param_group):

0 commit comments

Comments (0)