Commit 5538b6b

Merge pull request #47 from kozistr/feature/madgrad
[Fix] sparse gradient for MADGRAD
2 parents afa33ed + d3fc1dc commit 5538b6b

9 files changed: +34 -18 lines changed


pytorch_optimizer/adamp.py

Lines changed: 2 additions & 2 deletions

@@ -38,9 +38,9 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
-        """
+        """AdamP optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param delta: float. threshold that determines whether a set of parameters is scale invariant or not
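For context on the parameters this docstring now documents, here is a minimal usage sketch. It relies only on load_optimizers, which the tests in this commit import; the model and the hyper-parameter values are illustrative, not taken from the diff.

import torch
from torch import nn

from pytorch_optimizer import load_optimizers

# toy model; any nn.Module works
model = nn.Linear(10, 1)

# keyword arguments mirror the docstring above; the values here are examples
optimizer = load_optimizers('adamp')(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    delta=0.1,
    adamd_debias_term=False,
    eps=1e-8,
)

optimizer.zero_grad()
loss = model(torch.randn(4, 10)).mean()
loss.backward()
optimizer.step()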

pytorch_optimizer/diffgrad.py

Lines changed: 2 additions & 2 deletions

@@ -31,9 +31,9 @@ def __init__(
         weight_decay: float = 0.0,
         adamd_debias_term: bool = False,
     ):
-        """
+        """DiffGrad
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)

pytorch_optimizer/diffrgrad.py

Lines changed: 2 additions & 2 deletions

@@ -38,11 +38,11 @@ def __init__(
     ):
         """Blend RAdam with DiffGrad
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: int. (recommended is 5)
-        :param degenerated_to_sgd: bool..
+        :param degenerated_to_sgd: bool. degenerated to SGD
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """

pytorch_optimizer/lookahead.py

Lines changed: 2 additions & 2 deletions

@@ -31,8 +31,8 @@ def __init__(
         alpha: float = 0.5,
         pullback_momentum: str = 'none',
     ):
-        """
-        :param optimizer: Optimizer.
+        """Lookahead
+        :param optimizer: Optimizer. base optimizer
         :param k: int. number of lookahead steps
         :param alpha: float. linear interpolation factor
         :param pullback_momentum: str. change to inner optimizer momentum on interpolation update
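Since the docstring now spells out that Lookahead wraps a base optimizer, a minimal usage sketch may help. It assumes the base optimizer is passed as an instance (the Optimizer type in the docstring); the hyper-parameter values are illustrative.

from torch import nn

from pytorch_optimizer import Lookahead, load_optimizers

model = nn.Linear(10, 1)

# wrap a concrete optimizer instance with Lookahead
base_optimizer = load_optimizers('adamp')(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5, pullback_momentum='none')

# from here on it is used like any other torch optimizer:
# optimizer.zero_grad(); loss.backward(); optimizer.step()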

pytorch_optimizer/sgdp.py

Lines changed: 2 additions & 2 deletions

@@ -36,9 +36,9 @@ def __init__(
         wd_ratio: float = 0.1,
         nesterov: bool = False,
     ):
-        """
+        """SGDP optimizer
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param momentum: float. momentum factor
         :param dampening: float. dampening for momentum
         :param eps: float. term added to the denominator to improve numerical stability

pytorch_optimizer/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__VERSION__ = '0.3.3'
+__VERSION__ = '0.3.4'

tests/test_optimizer_parameters.py

Lines changed: 13 additions & 2 deletions

@@ -2,7 +2,7 @@
 
 import pytest
 
-from pytorch_optimizer import SAM, load_optimizers
+from pytorch_optimizer import SAM, Lookahead, load_optimizers
 
 OPTIMIZER_NAMES: List[str] = [
     'adamp',
@@ -67,6 +67,17 @@ def test_betas(optimizer_names):
         optimizer(None, betas=(0.1, -0.1))
 
 
-def test_rho():
+def test_sam_parameters():
     with pytest.raises(ValueError):
         SAM(None, load_optimizers('adamp'), rho=-0.1)
+
+
+def test_lookahead_parameters():
+    with pytest.raises(ValueError):
+        Lookahead(load_optimizers('adamp'), k=0)
+
+    with pytest.raises(ValueError):
+        Lookahead(load_optimizers('adamp'), alpha=0)
+
+    with pytest.raises(ValueError):
+        Lookahead(load_optimizers('adamp'), pullback_momentum='asdf')
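test_lookahead_parameters expects a ValueError for k=0, alpha=0, and an unrecognized pullback_momentum. Below is a minimal sketch of the kind of argument validation this implies; the exact checks, accepted pullback_momentum values, and error messages in lookahead.py may differ.

def validate_lookahead_parameters(k: int, alpha: float, pullback_momentum: str):
    # hypothetical standalone helper mirroring what the test exercises
    if k < 1:
        raise ValueError(f'invalid k : {k}')
    if alpha <= 0.0:
        raise ValueError(f'invalid alpha : {alpha}')
    # accepted values assumed from common Lookahead implementations
    if pullback_momentum not in ('none', 'reset', 'pullback'):
        raise ValueError(f'invalid pullback_momentum : {pullback_momentum}')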

tests/test_optimizers.py

Lines changed: 3 additions & 2 deletions

@@ -197,8 +197,9 @@ def test_f16_optimizers(optimizer_fp16_config):
     assert init_loss - 0.01 > loss
 
 
+@pytest.mark.parametrize('adaptive', (False, True))
 @pytest.mark.parametrize('optimizer_sam_config', FP32_OPTIMIZERS, ids=ids)
-def test_sam_optimizers(optimizer_sam_config):
+def test_sam_optimizers(adaptive, optimizer_sam_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
@@ -207,7 +208,7 @@ def test_sam_optimizers(optimizer_sam_config):
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()
 
     optimizer_class, config, iterations = optimizer_sam_config
-    optimizer = SAM(model.parameters(), optimizer_class, **config)
+    optimizer = SAM(model.parameters(), optimizer_class, **config, adaptive=adaptive)
 
     loss: float = np.inf
     init_loss: float = np.inf
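The adaptive=True case added here switches SAM to its ASAM-style, scale-aware perturbation. A minimal training-step sketch with that flag, assuming the two-call interface (first_step/second_step) that SAM implementations commonly expose; those method names are an assumption, not something shown in this diff.

import torch
from torch import nn

from pytorch_optimizer import SAM, load_optimizers

model = nn.Linear(10, 1)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = SAM(model.parameters(), load_optimizers('adamp'), lr=1e-3, adaptive=True)

x = torch.randn(8, 10)
y = torch.randint(0, 2, (8, 1)).float()

# first pass: compute gradients and move to the perturbed (sharpness-probing) weights
loss_fn(model(x), y).backward()
optimizer.first_step(zero_grad=True)   # assumed method name

# second pass: gradients at the perturbed weights drive the actual update
loss_fn(model(x), y).backward()
optimizer.second_step(zero_grad=True)  # assumed method name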

tests/test_sparse_gradient.py

Lines changed: 7 additions & 3 deletions

@@ -31,10 +31,9 @@ def test_sparse_not_supported(no_sparse_optimizer):
     grad = torch.randn(1, 1).to_sparse(1)
     param.grad = grad
 
-    optimizer = load_optimizers(optimizer=no_sparse_optimizer)([param])
-    optimizer.zero_grad()
-
     with pytest.raises(RuntimeError):
+        optimizer = load_optimizers(optimizer=no_sparse_optimizer)([param])
+        optimizer.zero_grad()
         optimizer.step()
 
 
@@ -47,3 +46,8 @@ def test_sparse_supported(sparse_optimizer):
     optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0)
     optimizer.zero_grad()
     optimizer.step()
+
+    with pytest.raises(RuntimeError):
+        optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0, weight_decay=1e-3)
+        optimizer.zero_grad()
+        optimizer.step()
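The added assertion pins down what this PR fixes: the sparse-capable optimizer (MADGRAD, per the PR title) accepts sparse gradients only when momentum is 0 and weight decay is left unset, and raises RuntimeError otherwise. A minimal sketch of the supported path, assuming 'madgrad' is the registry name used by load_optimizers (as the sparse_optimizer fixture suggests); the embedding and data are illustrative.

import torch
from torch import nn

from pytorch_optimizer import load_optimizers

# sparse=True makes the embedding produce sparse gradients
embedding = nn.Embedding(1000, 16, sparse=True)

# per the test above: momentum must be 0.0 and weight_decay is not passed
optimizer = load_optimizers('madgrad')(embedding.parameters(), momentum=0.0)

optimizer.zero_grad()
loss = embedding(torch.tensor([1, 2, 3])).sum()
loss.backward()   # the gradient on embedding.weight is sparse here
optimizer.step()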
