Skip to content

Commit be8e0bc

Browse files
committed
refactor: Lookahead optimizer
1 parent 1e502c2 commit be8e0bc

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from collections import defaultdict, OrderedDict
2-
from typing import Dict, Callable
1+
from collections import defaultdict
2+
from typing import Callable, Dict
33

44
import torch
55
from torch.optim import Optimizer
66

77
from pytorch_optimizer.base.optimizer import BaseOptimizer
8-
from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE, DEFAULTS
8+
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE
99

1010

11-
class Lookahead(BaseOptimizer, Optimizer):
11+
class Lookahead(Optimizer, BaseOptimizer):
1212
r"""k steps forward, 1 step back.
1313
1414
:param optimizer: OPTIMIZER. base optimizer.
@@ -28,6 +28,9 @@ def __init__(
2828
self.validate_range(alpha, 'alpha', 0.0, 1.0)
2929
self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])
3030

31+
self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
32+
self._optimizer_step_post_hooks: Dict[int, Callable] = {}
33+
3134
self.alpha = alpha
3235
self.k = k
3336
self.pullback_momentum = pullback_momentum
@@ -47,15 +50,13 @@ def __init__(
4750
state['slow_params'].copy_(p)
4851
if self.pullback_momentum == 'pullback':
4952
state['slow_momentum'] = torch.zeros_like(p)
50-
51-
# Instead of calling super().__init__, we set the attributes ourselves
52-
self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
53-
self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
53+
5454
self.defaults: DEFAULTS = {
55+
'lookahead_alpha': alpha,
56+
'lookahead_k': k,
57+
'lookahead_pullback_momentum': pullback_momentum,
5558
**optimizer.defaults,
56-
**dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_pullback_momentum=pullback_momentum),
5759
}
58-
5960

6061
def __getstate__(self):
6162
return {

0 commit comments

Comments
 (0)