Skip to content

Commit ce34d10

Browse files
Avoid calling __init__
1 parent cdae3db commit ce34d10

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from collections import defaultdict
2-
from typing import Dict
1+
from collections import defaultdict, OrderedDict
2+
from typing import Dict, Callable
33

44
import torch
55
from torch.optim import Optimizer
@@ -33,9 +33,7 @@ def __init__(
3333
self.pullback_momentum = pullback_momentum
3434

3535
self.optimizer = optimizer
36-
37-
defaults: DEFAULTS = dict()
38-
super().__init__(optimizer.param_groups, defaults)
36+
self.param_groups = self.optimizer.param_groups
3937

4038
self.state: STATE = defaultdict(dict)
4139

@@ -50,6 +48,14 @@ def __init__(
5048
if self.pullback_momentum == 'pullback':
5149
state['slow_momentum'] = torch.zeros_like(p)
5250

51+
# Instead of calling super().__init__, we set the attributes ourselves
52+
self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
53+
self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
54+
self.defaults: DEFAULTS = {
55+
**optimizer.defaults,
56+
**dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_pullback_momentum=pullback_momentum),
57+
}
58+
5359

5460
def __getstate__(self):
5561
return {

0 commit comments

Comments
 (0)