Skip to content

Commit 0a23375

Browse files
authored
Merge pull request #202 from kozistr/fix/lookahead-optimizer
[Refactor] Lookahead optimizer
2 parents ca8566e + 3bee1c5 commit 0a23375

File tree

8 files changed

+69
-49
lines changed

8 files changed

+69
-49
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Contributions to `pytorch-optimizer` for code, documentation, and tests are alwa
77
Currently, `black` and `ruff` are used to format & lint the code. Here are the [lint options](https://github.com/kozistr/pytorch_optimizer/blob/main/pyproject.toml#L69)
88
Or you just simply run `make format` and `make check` on the project root.
99

10+
You can create the environment with `make init` or just install the pip packages to your computer.
11+
1012
A few differences from the default `black` (or another style guide) are
1113

1214
1. line-length is **119** characters.

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
.PHONY: init format test check requirements docs
22

33
init:
4-
python -m pip install -q -U poetry
5-
python -m poetry install
4+
python -m pip install -q -U poetry isort black ruff pytest pytest-cov
5+
python -m poetry install --dev
66

77
format:
88
isort --profile black -l 119 pytorch_optimizer tests hubconf.py

docs/changelogs/v2.11.2.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
## Change Log
2+
3+
### Fix
4+
5+
* Fix Lookahead optimizer (#200, #201, #202)
6+
* When using PyTorch Lightning which expects your optimiser to be a subclass of `Optimizer`.
7+
8+
### Contributions
9+
10+
thanks to @georg-wolflein
11+
12+
### Diff
13+
14+
[2.11.1...2.11.2](https://github.com/kozistr/pytorch_optimizer/compare/v2.11.1...v2.11.2)

poetry.lock

Lines changed: 34 additions & 31 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ black = [
6060
{ version = "==23.3.0", python = ">=3.7,<3.8" },
6161
{ version = "^23.7.0", python = ">=3.8"}
6262
]
63-
ruff = "^0.0.284"
63+
ruff = "^0.0.286"
6464
pytest = "^7.4.0"
6565
pytest-cov = "^4.1.0"
6666

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from collections import defaultdict, OrderedDict
2-
from typing import Dict, Callable
1+
from collections import defaultdict
2+
from typing import Callable, Dict
33

44
import torch
55
from torch.optim import Optimizer
66

77
from pytorch_optimizer.base.optimizer import BaseOptimizer
8-
from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE, DEFAULTS
8+
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE
99

1010

11-
class Lookahead(BaseOptimizer, Optimizer):
11+
class Lookahead(Optimizer, BaseOptimizer):
1212
r"""k steps forward, 1 step back.
1313
1414
:param optimizer: OPTIMIZER. base optimizer.
@@ -28,6 +28,9 @@ def __init__(
2828
self.validate_range(alpha, 'alpha', 0.0, 1.0)
2929
self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])
3030

31+
self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
32+
self._optimizer_step_post_hooks: Dict[int, Callable] = {}
33+
3134
self.alpha = alpha
3235
self.k = k
3336
self.pullback_momentum = pullback_momentum
@@ -47,15 +50,13 @@ def __init__(
4750
state['slow_params'].copy_(p)
4851
if self.pullback_momentum == 'pullback':
4952
state['slow_momentum'] = torch.zeros_like(p)
50-
51-
# Instead of calling super().__init__, we set the attributes ourselves
52-
self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
53-
self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
53+
5454
self.defaults: DEFAULTS = {
55+
'lookahead_alpha': alpha,
56+
'lookahead_k': k,
57+
'lookahead_pullback_momentum': pullback_momentum,
5558
**optimizer.defaults,
56-
**dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_pullback_momentum=pullback_momentum),
5759
}
58-
5960

6061
def __getstate__(self):
6162
return {

requirements-dev.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
black==23.3.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
44
black==23.7.0 ; python_version >= "3.8" and python_full_version < "4.0.0"
5-
click==8.1.6 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
5+
click==8.1.7 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
66
colorama==0.4.6 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0" and (sys_platform == "win32" or platform_system == "Windows")
77
coverage[toml]==7.2.7 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
8-
exceptiongroup==1.1.2 ; python_full_version >= "3.7.2" and python_version < "3.11"
9-
filelock==3.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
8+
exceptiongroup==1.1.3 ; python_full_version >= "3.7.2" and python_version < "3.11"
9+
filelock==3.12.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
1010
importlib-metadata==6.7.0 ; python_full_version >= "3.7.2" and python_version < "3.8"
1111
iniconfig==2.0.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
1212
isort==5.11.5 ; python_full_version >= "3.7.2" and python_version < "3.8"
@@ -24,7 +24,7 @@ platformdirs==3.10.0 ; python_full_version >= "3.7.2" and python_full_version <
2424
pluggy==1.2.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2525
pytest-cov==4.1.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2626
pytest==7.4.0 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
27-
ruff==0.0.284 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
27+
ruff==0.0.286 ; python_full_version >= "3.7.2" and python_full_version < "4.0.0"
2828
sympy==1.12 ; python_version >= "3.8" and python_full_version < "4.0.0"
2929
tomli==2.0.1 ; python_full_version >= "3.7.2" and python_version < "3.11"
3030
torch==1.13.1+cpu ; python_full_version >= "3.7.2" and python_version < "3.8"

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
22

3-
filelock==3.12.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
3+
filelock==3.12.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
44
jinja2==3.1.2 ; python_version >= "3.8" and python_full_version < "4.0.0"
55
markupsafe==2.1.3 ; python_version >= "3.8" and python_full_version < "4.0.0"
66
mpmath==1.3.0 ; python_version >= "3.8" and python_full_version < "4.0.0"

0 commit comments

Comments
 (0)