Skip to content

Commit 121a3fc

Browse files
committed
refactor: Lookahead optimizer
1 parent 74d400f commit 121a3fc

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pytorch_optimizer/lookahead.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class Lookahead(Optimizer):
1818
"""
19-
Reference : https://github.com/alphadl/lookahead.pytorch/blob/master/lookahead.py
19+
Reference : https://github.com/alphadl/lookahead.pytorch
2020
Example :
2121
from pytorch_optimizer import AdamP, Lookahead
2222
...
@@ -42,7 +42,8 @@ def __init__(
4242
:param optimizer: Optimizer.
4343
:param k: int. number of lookahead steps
4444
:param alpha: float. linear interpolation factor
45-
:param pullback_momentum: str. change to inner optimizer momentum on interpolation update
45+
:param pullback_momentum: str. change to inner optimizer momentum
46+
on interpolation update
4647
"""
4748
self.optimizer = optimizer
4849
self.k = k
@@ -66,7 +67,7 @@ def __init__(
6667
)
6768

6869
def check_valid_parameters(self):
69-
if 1 > self.k:
70+
if self.k < 1:
7071
raise ValueError(f'Invalid k : {self.k}')
7172
if not 0.0 < self.alpha <= 1.0:
7273
raise ValueError(f'Invalid alpha : {self.alpha}')

0 commit comments

Comments
 (0)