
Commit dc3c356

Merge pull request #45 from kozistr/test/cases
[Test] Add more test cases
2 parents (5f1ef59 + 033a842), commit dc3c356

File tree (8 files changed, +167 / -29 lines):

  pytorch_optimizer/adabelief.py
  pytorch_optimizer/adabound.py
  pytorch_optimizer/pcgrad.py
  pytorch_optimizer/radam.py
  pytorch_optimizer/ranger21.py
  tests/test_load_optimizers.py
  tests/test_optimizer_parameters.py
  tests/test_optimizers.py


pytorch_optimizer/adabelief.py

Lines changed: 16 additions & 2 deletions
@@ -38,14 +38,14 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-16,
     ):
-        """
+        """AdaBelief
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: (recommended is 5)
         :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
-        :param fixed_decay: bool.
+        :param fixed_decay: bool. fix weight decay
         :param rectify: bool. perform the rectified update similar to RAdam
         :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
         :param amsgrad: bool. whether to use the AMSBound variant
@@ -63,6 +63,8 @@ def __init__(
         self.adamd_debias_term = adamd_debias_term
         self.eps = eps

+        self.check_valid_parameters()
+
         buffer: BUFFER = [[None, None, None] for _ in range(10)]

         if is_valid_parameters(params):
@@ -81,6 +83,18 @@ def __init__(
         )
         super().__init__(params, defaults)

+    def check_valid_parameters(self):
+        if self.lr < 0.0:
+            raise ValueError(f'Invalid learning rate : {self.lr}')
+        if not 0.0 <= self.betas[0] < 1.0:
+            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
+        if not 0.0 <= self.betas[1] < 1.0:
+            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
+        if self.weight_decay < 0.0:
+            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
+        if self.eps < 0.0:
+            raise ValueError(f'Invalid eps : {self.eps}')
+
     def __setstate__(self, state: STATE):
         super().__setstate__(state)
         for group in self.param_groups:
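
The practical effect of the new guard: invalid hyperparameters now fail at construction time instead of surfacing as silent training bugs. A minimal sketch (assuming AdaBelief is exported from the package root, as the other optimizers in this repository are):

import torch
from pytorch_optimizer import AdaBelief

model = torch.nn.Linear(2, 1)

# check_valid_parameters() runs inside __init__, so a bad value is
# rejected before the base Optimizer is even constructed
try:
    AdaBelief(model.parameters(), lr=-1e-3)
except ValueError as err:
    print(err)  # Invalid learning rate : -0.001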

pytorch_optimizer/adabound.py

Lines changed: 4 additions & 2 deletions
@@ -37,15 +37,15 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
-        """
+        """AdaBound
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate
         :param final_lr: float. final learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param gamma: float. convergence speed of the bound functions
         :param weight_decay: float. weight decay (L2 penalty)
         :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
-        :param fixed_decay: bool.
+        :param fixed_decay: bool. fix weight decay
         :param amsbound: bool. whether to use the AMSBound variant
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
@@ -57,6 +57,8 @@ def __init__(
         self.fixed_decay = fixed_decay
         self.eps = eps

+        self.check_valid_parameters()
+
         defaults: DEFAULTS = dict(
             lr=lr,
             betas=betas,

pytorch_optimizer/pcgrad.py

Lines changed: 17 additions & 18 deletions
@@ -1,6 +1,6 @@
 import random
 from copy import deepcopy
-from typing import Iterable, List
+from typing import Iterable, List, Tuple

 import numpy as np
 import torch
@@ -35,12 +35,12 @@ def check_valid_parameters(self):
             raise ValueError(f'invalid reduction : {self.reduction}')

     @staticmethod
-    def flatten_grad(grads) -> torch.Tensor:
+    def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
         return torch.cat([g.flatten() for g in grads])

     @staticmethod
     def un_flatten_grad(grads, shapes) -> List[torch.Tensor]:
-        un_flatten_grad = []
+        un_flatten_grad: List[torch.Tensor] = []
         idx: int = 0
         for shape in shapes:
             length = np.prod(shape)
@@ -54,39 +54,40 @@ def zero_grad(self):
     def step(self):
         return self.optimizer.step()

-    def set_grad(self, grads):
+    def set_grad(self, grads: List[torch.Tensor]):
         idx: int = 0
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 p.grad = grads[idx]
                 idx += 1

-    def retrieve_grad(self):
+    def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
         """get the gradient of the parameters of the network with specific objective"""
         grad, shape, has_grad = [], [], []
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     shape.append(p.shape)
-                    grad.append(torch.zeros_like(p).to(p.device))
-                    has_grad.append(torch.zeros_like(p).to(p.device))
+                    grad.append(torch.zeros_like(p, device=p.device))
+                    has_grad.append(torch.zeros_like(p, device=p.device))
                     continue

                 shape.append(p.grad.shape)
                 grad.append(p.grad.clone())
-                has_grad.append(torch.ones_like(p).to(p.device))
+                has_grad.append(torch.ones_like(p, device=p.device))

         return grad, shape, has_grad

-    def pack_grad(self, objectives: Iterable[nn.Module]):
+    def pack_grad(
+        self, objectives: Iterable[nn.Module]
+    ) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
         """pack the gradient of the parameters of the network for each objective
-        :param objectives: Iterable[float]. a list of objectives
+        :param objectives: Iterable[nn.Module]. a list of objectives
         :return:
         """
         grads, shapes, has_grads = [], [], []
         for objective in objectives:
-            self.zero_grad()
-
+            self.optimizer.zero_grad(set_to_none=True)
             objective.backward(retain_graph=True)

             grad, shape, has_grad = self.retrieve_grad()
@@ -98,7 +99,7 @@ def pack_grad(self, objectives: Iterable[nn.Module]):
         return grads, shapes, has_grads

     def project_conflicting(self, grads, has_grads) -> torch.Tensor:
-        """
+        """project conflicting
         :param grads: a list of the gradient of the parameters
         :param has_grads: a list of mask represent whether the parameter has gradient
         :return:
@@ -114,12 +115,10 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:
                 g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

         merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
-        merged_grad[shared] = torch.stack([g[shared] for g in pc_grad])
-
         if self.reduction == 'mean':
-            merged_grad = merged_grad.mean(dim=0)
-        else:  # self.reduction == 'sum'
-            merged_grad = merged_grad.sum(dim=0)
+            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
+        else:
+            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)

         merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
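
For context, PCGrad wraps an inner optimizer: pc_backward packs one gradient per objective, projects away the conflicting components, and writes the merged gradient back before the wrapped optimizer steps. A minimal usage sketch, mirroring the new test further down (the AdamP inner optimizer and the toy losses are illustrative choices, not part of this diff):

import torch
import torch.nn.functional as F
from pytorch_optimizer import AdamP, PCGrad

model = torch.nn.Linear(2, 2)
optimizer = PCGrad(AdamP(model.parameters(), lr=1e-3))

x = torch.randn(8, 2)
target_a = torch.randn(8, 2)
target_b = torch.randn(8, 2)

y = model(x)
loss_a = F.mse_loss(y, target_a)  # objective 1
loss_b = F.l1_loss(y, target_b)   # objective 2

optimizer.zero_grad()
optimizer.pc_backward([loss_a, loss_b])  # project conflicting per-task gradients
optimizer.step()                         # step the wrapped inner optimizer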

pytorch_optimizer/radam.py

Lines changed: 3 additions & 3 deletions
@@ -35,13 +35,13 @@ def __init__(
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
-        """
+        """RAdam
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: int. (recommended is 5)
-        :param degenerated_to_sgd: float.
+        :param degenerated_to_sgd: float. degenerated to SGD
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """

pytorch_optimizer/ranger21.py

Lines changed: 14 additions & 0 deletions
@@ -92,6 +92,8 @@ def __init__(
         self.norm_loss_factor = norm_loss_factor
         self.eps = eps

+        self.check_valid_parameters()
+
         # lookahead
         self.lookahead_step: int = 0

@@ -124,6 +126,18 @@ def __init__(
         self.start_warm_down: int = num_iterations - self.num_warm_down_iterations
         self.warm_down_lr_delta: float = self.starting_lr - self.min_lr

+    def check_valid_parameters(self):
+        if self.lr < 0.0:
+            raise ValueError(f'Invalid learning rate : {self.lr}')
+        if not 0.0 <= self.betas[0] < 1.0:
+            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
+        if not 0.0 <= self.betas[1] < 1.0:
+            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
+        if self.weight_decay < 0.0:
+            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
+        if self.eps < 0.0:
+            raise ValueError(f'Invalid eps : {self.eps}')
+
     def __setstate__(self, state: STATE):
         super().__setstate__(state)

tests/test_load_optimizers.py

Lines changed: 1 addition & 4 deletions
@@ -36,8 +36,5 @@ def test_load_optimizers_valid(valid_optimizer_names):

 @pytest.mark.parametrize('invalid_optimizer_names', INVALID_OPTIMIZER_NAMES)
 def test_load_optimizers_invalid(invalid_optimizer_names):
-    try:
+    with pytest.raises(NotImplementedError):
         load_optimizers(invalid_optimizer_names)
-    except NotImplementedError:
-        return True
-    return False

tests/test_optimizer_parameters.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+from typing import List
+
+import pytest
+
+from pytorch_optimizer import load_optimizers
+
+OPTIMIZER_NAMES: List[str] = [
+    'adamp',
+    'sgdp',
+    'madgrad',
+    'ranger',
+    'ranger21',
+    'radam',
+    'adabound',
+    'adahessian',
+    'adabelief',
+    'diffgrad',
+    'diffrgrad',
+    'lamb',
+]
+
+BETA_OPTIMIZER_NAMES: List[str] = [
+    'adabelief',
+    'adabound',
+    'adahessian',
+    'adamp',
+    'diffgrad',
+    'diffrgrad',
+    'lamb',
+    'radam',
+    'ranger',
+    'ranger21',
+]
+
+
+@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
+def test_learning_rate(optimizer_names):
+    with pytest.raises(ValueError):
+        optimizer = load_optimizers(optimizer_names)
+        optimizer(None, lr=-1e-2)
+
+
+@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
+def test_epsilon(optimizer_names):
+    with pytest.raises(ValueError):
+        optimizer = load_optimizers(optimizer_names)
+        optimizer(None, eps=-1e-6)
+
+
+@pytest.mark.parametrize('optimizer_names', OPTIMIZER_NAMES)
+def test_weight_decay(optimizer_names):
+    with pytest.raises(ValueError):
+        optimizer = load_optimizers(optimizer_names)
+        optimizer(None, weight_decay=-1e-3)
+
+
+@pytest.mark.parametrize('optimizer_names', BETA_OPTIMIZER_NAMES)
+def test_betas(optimizer_names):
+    with pytest.raises(ValueError):
+        optimizer = load_optimizers(optimizer_names)
+        optimizer(None, betas=(-0.1, 0.1))
+
+    with pytest.raises(ValueError):
+        optimizer = load_optimizers(optimizer_names)
+        optimizer(None, betas=(0.1, -0.1))
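
These tests rely on the fact that load_optimizers resolves an optimizer class by name; the class is then instantiated with deliberately invalid hyperparameters so the new check_valid_parameters guards fire. A minimal sketch of the same pattern outside pytest (the concrete values are illustrative):

from pytorch_optimizer import load_optimizers

opt_class = load_optimizers('adabelief')  # returns the optimizer class, not an instance
try:
    opt_class(None, lr=-1e-2)             # rejected in __init__ before params are touched
except ValueError as err:
    print(err)  # Invalid learning rate : -0.01

# unknown names raise NotImplementedError (see tests/test_load_optimizers.py)
# load_optimizers('no_such_optimizer')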

tests/test_optimizers.py

Lines changed: 47 additions & 0 deletions
@@ -17,6 +17,7 @@
     DiffRGrad,
     Lamb,
     Lookahead,
+    PCGrad,
     RAdam,
     Ranger,
     Ranger21,
@@ -39,6 +40,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class MultiHeadLogisticRegression(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(2, 2)
+        self.head1 = nn.Linear(2, 1)
+        self.head2 = nn.Linear(2, 1)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.fc1(x)
+        x = F.relu(x)
+        return self.head1(x), self.head2(x)
+
+
 def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
     rng = np.random.RandomState(seed)

@@ -180,4 +194,37 @@ def test_sam_optimizers(optimizer_config):
         loss_fn(y_data, model(x_data)).backward()
         optimizer.second_step(zero_grad=True)

+        if init_loss == np.inf:
+            init_loss = loss
+
+    assert init_loss > 2.0 * loss
+
+
+@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
+def test_pc_grad_optimizers(optimizer_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = MultiHeadLogisticRegression()
+    loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
+    loss_fn_2: nn.Module = nn.L1Loss()
+
+    optimizer_class, config, iterations = optimizer_config
+    optimizer = PCGrad(optimizer_class(model.parameters(), **config))
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(iterations):
+        optimizer.zero_grad()
+        y_pred_1, y_pred_2 = model(x_data)
+        loss1, loss2 = loss_fn_1(y_pred_1, y_data), loss_fn_2(y_pred_2, y_data)
+
+        loss = (loss1 + loss2) / 2.0
+        if init_loss == np.inf:
+            init_loss = loss
+
+        optimizer.pc_backward([loss1, loss2])
+        optimizer.step()
+
     assert init_loss > 2.0 * loss
