Skip to content

Commit 75e1c6d

Browse files
authored
Merge pull request #60 from kozistr/feature/list-optimizers
[Feature] List of supported optimizers
2 parents c6e3159 + b087e56 commit 75e1c6d

File tree

14 files changed

+83
-100
lines changed

14 files changed

+83
-100
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,27 @@ jobs:
1515

1616
steps:
1717
- uses: actions/checkout@v3
18-
- name: set up Python ${{ matrix.python-version }}
18+
- name: Set up Python ${{ matrix.python-version }}
1919
uses: actions/setup-python@v3
2020
with:
2121
python-version: ${{ matrix.python-version }}
22-
- name: cache pip
22+
- name: Cache pip
2323
uses: actions/cache@v3
2424
with:
2525
path: ~/.cache/pip
2626
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
2727
restore-keys: |
2828
${{ runner.os }}-pip-
2929
${{ runner.os }}-
30-
- name: install dependencies
30+
- name: Install dependencies
3131
run: pip install -r requirements-dev.txt
32-
- name: check lint
32+
- name: Check lint
3333
run: make check
34-
- name: check test
34+
- name: Check test
3535
run: |
3636
export PYTHONDONTWRITEBYTECODE=1
3737
make test
38-
- name: check codecov
38+
- name: Check codecov
3939
uses: codecov/codecov-action@v2
4040
with:
4141
token: ${{ secrets.CODECOV_TOKEN }}

.github/workflows/publish.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
steps:
1313
- name: Checkout code
1414
uses: actions/checkout@v3
15-
- name: Create Release
15+
- name: Create release
1616
id: create_release
1717
uses: actions/create-release@v1
1818
env:
@@ -32,7 +32,7 @@ jobs:
3232
uses: actions/setup-python@v3
3333
with:
3434
python-version: 3.9
35-
- name: cache pip
35+
- name: Cache pip
3636
uses: actions/cache@v3
3737
with:
3838
path: ~/.cache/pip

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ score=yes
9696
[REFACTORING]
9797

9898
# Maximum number of nested blocks for function / method body
99-
max-nested-blocks=5
99+
max-nested-blocks=6
100100

101101

102102
[BASIC]

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ or you can use optimizer loader, simply passing a name of the optimizer.
3838

3939
::
4040

41-
from pytorch_optimizer import load_optimizers
41+
from pytorch_optimizer import load_optimizer
4242

4343
...
4444
model = YourModel()
45-
opt = load_optimizers(optimizer='adamp')
45+
opt = load_optimizer(optimizer='adamp')
4646
optimizer = opt(model.parameters())
4747
...
4848

lint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_configuration() -> Namespace:
1414
parser.add_argument(
1515
'-t',
1616
'--threshold',
17-
default=9.98,
17+
default=10.0,
1818
type=float,
1919
)
2020

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "1.1.0"
3+
version = "1.1.1"
44
description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# pylint: disable=unused-import
2+
from typing import Dict, List, Type
3+
4+
from torch.optim import Optimizer
5+
26
from pytorch_optimizer.adabelief import AdaBelief
37
from pytorch_optimizer.adabound import AdaBound
48
from pytorch_optimizer.adamp import AdamP
@@ -14,7 +18,6 @@
1418
from pytorch_optimizer.lookahead import Lookahead
1519
from pytorch_optimizer.madgrad import MADGRAD
1620
from pytorch_optimizer.nero import Nero
17-
from pytorch_optimizer.optimizers import load_optimizers
1821
from pytorch_optimizer.pcgrad import PCGrad
1922
from pytorch_optimizer.pnm import PNM
2023
from pytorch_optimizer.radam import RAdam
@@ -31,3 +34,37 @@
3134
normalize_gradient,
3235
unit_norm,
3336
)
37+
38+
# All optimizer classes exposed by the package; source of truth for the
# name -> class registry built below.
OPTIMIZER_LIST: List[Type[Optimizer]] = [
    AdaBelief,
    AdaBound,
    AdamP,
    AdaPNM,
    DiffGrad,
    DiffRGrad,
    Lamb,
    LARS,
    MADGRAD,
    Nero,
    PNM,
    RAdam,
    RaLamb,
    Ranger,
    Ranger21,
    SGDP,
    Shampoo,
]
# Registry keyed by the lower-cased class name (e.g. 'adamp' -> AdamP).
# Values are optimizer *classes*, not instances, hence Type[Optimizer].
# __name__ is already a str, so no str() wrapper is needed.
OPTIMIZERS: Dict[str, Type[Optimizer]] = {optimizer.__name__.lower(): optimizer for optimizer in OPTIMIZER_LIST}
58+
59+
60+
def load_optimizer(optimizer: str) -> Type[Optimizer]:
    """Fetch an optimizer class by its (case-insensitive) name.

    :param optimizer: str. name of the optimizer, e.g. 'adamp'.
    :return: the optimizer class registered under that name
        (callers instantiate it themselves, e.g. ``opt(model.parameters())``).
    :raises NotImplementedError: when no optimizer is registered under the name.
    """
    # Normalize once; registry keys are lower-cased class names.
    optimizer = optimizer.lower()

    if optimizer not in OPTIMIZERS:
        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

    return OPTIMIZERS[optimizer]
67+
68+
69+
def get_supported_optimizers() -> List[Type[Optimizer]]:
    """Return the list of all optimizer classes supported by this package.

    NOTE: returns the module-level list itself (not a copy); callers should
    treat it as read-only.
    """
    return OPTIMIZER_LIST

pytorch_optimizer/fp16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def decrease_loss_scale(self):
8787

8888

8989
class SafeFP16Optimizer(Optimizer):
90-
def __init__(self, optimizer, aggregate_g_norms: bool = False):
90+
def __init__(self, optimizer, aggregate_g_norms: bool = False): # pylint: disable=super-init-not-called
9191
self.optimizer = optimizer
9292
self.aggregate_g_norms = aggregate_g_norms
9393

pytorch_optimizer/lookahead.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Lookahead(Optimizer, BaseOptimizer):
2525
optimizer.step()
2626
"""
2727

28-
def __init__(
28+
def __init__( # pylint: disable=super-init-not-called
2929
self,
3030
optimizer: Optimizer,
3131
k: int = 5,

pytorch_optimizer/optimizers.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)