Skip to content

Commit ca8566e

Browse files
authored
Merge pull request #201 from georg-wolflein/main
[Fix] Make Lookahead a subclass of Optimizer
2 parents ed0ed10 + ce34d10 commit ca8566e

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
-from collections import defaultdict
-from typing import Dict
+from collections import defaultdict, OrderedDict
+from typing import Dict, Callable
 
 import torch
+from torch.optim import Optimizer
 
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE
+from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE, DEFAULTS
 
 
-class Lookahead(BaseOptimizer):
+class Lookahead(BaseOptimizer, Optimizer):
     r"""k steps forward, 1 step back.
 
     :param optimizer: OPTIMIZER. base optimizer.
@@ -33,6 +34,7 @@ def __init__(
 
         self.optimizer = optimizer
         self.param_groups = self.optimizer.param_groups
+
         self.state: STATE = defaultdict(dict)
 
         for group in self.param_groups:
@@ -45,6 +47,15 @@ def __init__(
                 state['slow_params'].copy_(p)
                 if self.pullback_momentum == 'pullback':
                     state['slow_momentum'] = torch.zeros_like(p)
+
+        # Instead of calling super().__init__, we set the attributes ourselves
+        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
+        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
+        self.defaults: DEFAULTS = {
+            **optimizer.defaults,
+            **dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_pullback_momentum=pullback_momentum),
+        }
+
 
     def __getstate__(self):
         return {

0 commit comments

Comments (0)