
Commit 3c952d0

Merge pull request #26 from kozistr/refactor/lint
[Refactor] Add PyLint to CI
2 parents 5113c54 + 55fc044 commit 3c952d0


22 files changed: +309 lines, -551 lines


.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
@@ -33,3 +33,6 @@ jobs:
       - name: Lint
        run: |
          make format
+      - name: Check Lint
+        run: |
+          make check

.pylintrc

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@ ignore=
 ignore-patterns=ner_dependencies_codes
 
 # Use multiple processes to speed up Pylint.
-jobs=6
+jobs=8
 
 # List of plugins (as comma separated values of python modules names) to load,
 # usually to register additional checkers.
@@ -203,7 +203,7 @@ indent-after-paren=4
 indent-string='    '
 
 # Maximum number of characters on a single line.
-max-line-length=79
+max-line-length=119
 
 # Maximum number of lines in a module
 max-module-lines=800
@@ -378,7 +378,7 @@ max-bool-expr=5
 max-branches=10
 
 # Maximum number of locals for function / method body
-max-locals=20
+max-locals=30
 
 # Maximum number of parents for a class (see R0901).
 max-parents=10

Makefile

Lines changed: 8 additions & 8 deletions
@@ -1,13 +1,17 @@
-.PHONY: init check format requirements build deploy
+.PHONY: init format check build deploy requirements
 
 init:
     python3 -m pip install -U pipenv setuptools wheel
    python3 -m pipenv install --dev
 
+format:
+    isort --profile black -l 119 pytorch_optimizer setup.py lint.py
+    black -S -l 119 pytorch_optimizer setup.py lint.py
+
 check:
-    isort --check-only --profile black -l 79 pytorch_optimizer setup.py
-    black -S -l 79 --check pytorch_optimizer setup.py
-    pylint pytorch_optimizer
+    isort --check-only --profile black -l 119 pytorch_optimizer setup.py lint.py
+    black -S -l 119 --check pytorch_optimizer setup.py lint.py
+    python3 lint.py
 
 build:
    python3 setup.py sdist bdist_wheel
@@ -16,10 +20,6 @@ deploy:
    python3 -m twine check dist/*
    python3 -m twine upload dist/*
 
-format:
-    isort --profile black -l 79 pytorch_optimizer setup.py
-    black -S -l 79 pytorch_optimizer setup.py
-
 requirements:
    python3 -m pipenv lock -r > requirements.txt
    python3 -m pipenv lock -dr > requirements-dev.txt
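
For readers without make installed, a rough Python equivalent of the new check target is sketched below. This is a hypothetical helper, not part of this commit; it assumes isort, black, and the new lint.py are on the path and simply shells out to the same commands the Makefile runs, failing fast like the recipe does.

# Hypothetical sketch, not part of this commit: mirrors `make check` for
# environments without make.
import subprocess
import sys

CHECKS = [
    ['isort', '--check-only', '--profile', 'black', '-l', '119', 'pytorch_optimizer', 'setup.py', 'lint.py'],
    ['black', '-S', '-l', '119', '--check', 'pytorch_optimizer', 'setup.py', 'lint.py'],
    [sys.executable, 'lint.py'],
]

for cmd in CHECKS:
    print('running:', ' '.join(cmd))
    result = subprocess.run(cmd)
    if result.returncode != 0:
        sys.exit(result.returncode)  # stop at the first failing check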

lint.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+from argparse import ArgumentParser, Namespace
+
+from pylint.lint import Run
+
+
+def get_configuration() -> Namespace:
+    parser = ArgumentParser(prog='LINT')
+    parser.add_argument(
+        '-p',
+        '--path',
+        help='path to directory you want to run pylint | ' 'Default: %(default)s | ' 'Type: %(type)s ',
+        default='pytorch_optimizer',
+        type=str,
+    )
+    parser.add_argument(
+        '-t',
+        '--threshold',
+        help='score threshold to fail pylint runner | ' 'Default: %(default)s | ' 'Type: %(type)s ',
+        default=9.5,
+        type=float,
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args: Namespace = get_configuration()
+
+    path: str = str(args.path)
+    threshold: float = float(args.threshold)
+    print(f'PyLint Starting | path: {path} | threshold: {threshold:.2f}')
+
+    results = Run([path], do_exit=False)
+
+    final_score: float = results.linter.stats['global_note']
+    if final_score < threshold:
+        print(f'PyLint Failed | score: {final_score:.2f} | threshold: {threshold:.2f}')
+        raise Exception
+    else:
+        print(f'PyLint Passed | score: {final_score:.2f} | threshold: {threshold:.2f}')
+
+
+if __name__ == '__main__':
+    main()
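
One hedged note on the score lookup above: results.linter.stats['global_note'] works on the pylint versions this commit targets, but in pylint 2.12 and later linter.stats became a LinterStats object rather than a dict. A version-tolerant variant might look like the following sketch (a hypothetical helper, not part of this commit):

# Hypothetical sketch, not part of this commit: fetch the global pylint score
# in a way that tolerates both the old dict-style stats and the newer
# LinterStats object introduced in pylint 2.12.
from pylint.lint import Run


def pylint_score(path: str = 'pytorch_optimizer') -> float:
    results = Run([path], do_exit=False)  # do_exit=False: return control instead of exiting
    stats = results.linter.stats
    if isinstance(stats, dict):  # older pylint
        return stats['global_note']
    return stats.global_note  # pylint >= 2.12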

pytorch_optimizer/adabelief.py

Lines changed: 22 additions & 55 deletions
@@ -3,14 +3,7 @@
 import torch
 from torch.optim.optimizer import Optimizer
 
-from pytorch_optimizer.types import (
-    BETAS,
-    CLOSURE,
-    DEFAULT_PARAMETERS,
-    LOSS,
-    PARAMS,
-    STATE,
-)
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
 
 
 class AdaBelief(Optimizer):
@@ -31,60 +24,47 @@ class AdaBelief(Optimizer):
 
     def __init__(
         self,
-        params: PARAMS,
+        params: PARAMETERS,
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
-        eps: float = 1e-16,
         weight_decay: float = 0.0,
         n_sma_threshold: int = 5,
-        amsgrad: bool = False,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         rectify: bool = True,
         degenerated_to_sgd: bool = True,
+        amsgrad: bool = False,
+        eps: float = 1e-16,
     ):
-        """AdaBelief optimizer
-        :param params: PARAMS. iterable of parameters to optimize
-            or dicts defining parameter groups
+        """
+        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
-        :param betas: BETAS. coefficients used for computing running averages
-            of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator
-            to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param weight_decay: float. weight decay (L2 penalty)
        :param n_sma_threshold: (recommended is 5)
-        :param amsgrad: bool. whether to use the AMSBound variant
-        :param weight_decouple: bool. the optimizer uses decoupled weight decay
-            as in AdamW
+        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
        :param fixed_decay: bool.
        :param rectify: bool. perform the rectified update similar to RAdam
-        :param degenerated_to_sgd: bool. perform SGD update
-            when variance of gradient is high
+        :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
+        :param amsgrad: bool. whether to use the AMSBound variant
+        :param eps: float. term added to the denominator to improve numerical stability
        """
        self.lr = lr
        self.betas = betas
-        self.eps = eps
        self.weight_decay = weight_decay
        self.n_sma_threshold = n_sma_threshold
-        self.degenerated_to_sgd = degenerated_to_sgd
        self.weight_decouple = weight_decouple
-        self.rectify = rectify
        self.fixed_decay = fixed_decay
+        self.rectify = rectify
        self.degenerated_to_sgd = degenerated_to_sgd
+        self.eps = eps
 
-        if (
-            isinstance(params, (list, tuple))
-            and len(params) > 0
-            and isinstance(params[0], dict)
-        ):
+        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
-                if 'betas' in param and (
-                    param['betas'][0] != betas[0]
-                    or param['betas'][1] != betas[1]
-                ):
+                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
 
-        defaults: DEFAULT_PARAMETERS = dict(
+        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            eps=eps,
@@ -129,9 +109,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                grad = p.grad.data
                if grad.is_sparse:
-                    raise RuntimeError(
-                        'AdaBelief does not support sparse gradients'
-                    )
+                    raise RuntimeError('AdaBelief does not support sparse gradients')
 
                amsgrad = group['amsgrad']
 
@@ -163,9 +141,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(
-                    grad_residual, grad_residual, value=1 - beta2
-                )
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
 
                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
@@ -176,14 +152,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                        out=max_exp_avg_var,
                    )
 
-                    denom = (
-                        max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)
-                    ).add_(group['eps'])
+                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
-                    denom = (
-                        exp_avg_var.add_(group['eps']).sqrt()
-                        / math.sqrt(bias_correction2)
-                    ).add_(group['eps'])
+                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
                if not self.rectify:
                    step_size = group['lr'] / bias_correction1
@@ -196,9 +167,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                        buffered[0] = state['step']
                        beta2_t = beta2 ** state['step']
                        n_sma_max = 2 / (1 - beta2) - 1
-                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (
-                            1 - beta2_t
-                        )
+                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                        buffered[1] = n_sma
 
                        if n_sma >= self.n_sma_threshold:
@@ -219,9 +188,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                if n_sma >= self.n_sma_threshold:
                    denom = exp_avg_var.sqrt().add_(group['eps'])
-                    p.data.addcdiv_(
-                        exp_avg, denom, value=-step_size * group['lr']
-                    )
+                    p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                elif step_size > 0:
                    p.data.add_(exp_avg, alpha=-step_size * group['lr'])
 
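
The refactor above only reorders keyword arguments and reflows wrapped lines, so existing call sites keep working. A minimal usage sketch, assuming AdaBelief is exported from the package root; the hyperparameter names come straight from the signature in the diff above.

# Minimal sketch, assuming `from pytorch_optimizer import AdaBelief` is valid
# for this package layout. Not part of this commit.
import torch

from pytorch_optimizer import AdaBelief

model = torch.nn.Linear(10, 1)
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    n_sma_threshold=5,
    weight_decouple=True,
    rectify=True,
    amsgrad=False,  # amsgrad and eps now sit at the end of the signature
    eps=1e-16,
)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()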
