Skip to content

Commit cdae3db

Browse files
Make sure Lookahead is an Optimizer subclass
1 parent ed0ed10 commit cdae3db

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from typing import Dict
33

44
import torch
5+
from torch.optim import Optimizer
56

67
from pytorch_optimizer.base.optimizer import BaseOptimizer
7-
from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE
8+
from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE, DEFAULTS
89

910

10-
class Lookahead(BaseOptimizer):
11+
class Lookahead(BaseOptimizer, Optimizer):
1112
r"""k steps forward, 1 step back.
1213
1314
:param optimizer: OPTIMIZER. base optimizer.
@@ -32,7 +33,10 @@ def __init__(
3233
self.pullback_momentum = pullback_momentum
3334

3435
self.optimizer = optimizer
35-
self.param_groups = self.optimizer.param_groups
36+
37+
defaults: DEFAULTS = dict()
38+
super().__init__(optimizer.param_groups, defaults)
39+
3640
self.state: STATE = defaultdict(dict)
3741

3842
for group in self.param_groups:
@@ -45,6 +49,7 @@ def __init__(
4549
state['slow_params'].copy_(p)
4650
if self.pullback_momentum == 'pullback':
4751
state['slow_momentum'] = torch.zeros_like(p)
52+
4853

4954
def __getstate__(self):
5055
return {

0 commit comments

Comments (0)