
Commit e74a8d8

update: fp16, SAM test cases

1 parent 255d2ed

1 file changed: tests/test_optimizers.py (+75 −4 lines)
@@ -8,6 +8,7 @@
 
 from pytorch_optimizer import (
     MADGRAD,
+    SAM,
     SGDP,
     AdaBelief,
     AdaBound,
@@ -19,6 +20,7 @@
     RAdam,
     Ranger,
     Ranger21,
+    SafeFP16Optimizer,
 )
 
 __REFERENCE__ = 'https://github.com/jettify/pytorch-optimizer/blob/master/tests/test_optimizer_with_nn.py'
@@ -66,7 +68,7 @@ def build_lookahead(*parameters, **kwargs):
     return Lookahead(AdamP(*parameters, **kwargs))
 
 
-OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+FP32_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
     (AdaBelief, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
     (AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
@@ -81,17 +83,31 @@ def build_lookahead(*parameters, **kwargs):
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
 ]
 
+FP16_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+    (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
+    (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (DiffGrad, {'lr': 1e-1, 'weight_decay': 1e-3}, 500),
+    (DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (SGDP, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
+]
+

-@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
-def test_optimizers(optimizer_config):
+@pytest.mark.parametrize('optimizer_fp32_config', FP32_OPTIMIZERS, ids=ids)
+def test_f32_optimizers(optimizer_fp32_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
 
     model: nn.Module = LogisticRegression()
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()
 
-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_fp32_config
     optimizer = optimizer_class(model.parameters(), **config)
 
     loss: float = np.inf
@@ -110,3 +126,58 @@ def test_optimizers(optimizer_config):
     optimizer.step()
 
     assert init_loss > 2.0 * loss
+
+
+@pytest.mark.parametrize('optimizer_fp16_config', FP16_OPTIMIZERS, ids=ids)
+def test_f16_optimizers(optimizer_fp16_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_fp16_config
+    optimizer = SafeFP16Optimizer(optimizer_class(model.parameters(), **config))
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(1000):
+        optimizer.zero_grad()
+
+        y_pred = model(x_data)
+        loss = loss_fn(y_pred, y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+
+        optimizer.step()
+
+    assert init_loss - 0.01 > loss
+
+
+@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
+def test_sam_optimizers(optimizer_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_config
+    optimizer = SAM(model.parameters(), optimizer_class, **config)
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(iterations):
+        loss = loss_fn(model(x_data), y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+        optimizer.first_step(zero_grad=True)
+
+        loss_fn(model(x_data), y_data).backward()
+        optimizer.second_step(zero_grad=True)
+
+    assert init_loss > 2.0 * loss
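
For reference, the new test_f16_optimizers case treats SafeFP16Optimizer as a drop-in wrapper: the training loop stays a plain PyTorch loop, and only the optimizer construction changes. Below is a minimal standalone sketch of that wrapping pattern, assuming only the constructor, zero_grad, and step calls exercised by the test; the toy model and data are illustrative stand-ins, not the repo's make_dataset()/LogisticRegression() fixtures.

import torch
from torch import nn

from pytorch_optimizer import AdamP, SafeFP16Optimizer

# toy stand-ins for the test fixtures
model = nn.Linear(2, 1)
loss_fn = nn.BCEWithLogitsLoss()
x = torch.randn(64, 2)
y = torch.randint(0, 2, (64, 1)).float()

# wrap the inner optimizer; the loop below is an ordinary training loop
optimizer = SafeFP16Optimizer(AdamP(model.parameters(), lr=5e-1, weight_decay=1e-3))

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()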
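Likewise, test_sam_optimizers exercises SAM's two-pass protocol: the first backward pass plus first_step perturbs the weights toward the nearby high-loss ("sharp") point, and the second backward pass plus second_step applies the actual update from there via the base optimizer. A minimal sketch of that pattern, assuming only the SAM(params, base_optimizer, **kwargs), first_step, and second_step API used in the test; the model and data are again stand-ins.

import torch
from torch import nn

from pytorch_optimizer import SAM, AdamP

model = nn.Linear(2, 1)
loss_fn = nn.BCEWithLogitsLoss()
x = torch.randn(64, 2)
y = torch.randint(0, 2, (64, 1)).float()

# SAM takes the base optimizer class, not an instance
optimizer = SAM(model.parameters(), AdamP, lr=1e-2, weight_decay=1e-3)

for _ in range(200):
    # pass 1: gradients at the current weights, then ascend to the
    # worst-case point in the neighborhood
    loss_fn(model(x), y).backward()
    optimizer.first_step(zero_grad=True)

    # pass 2: gradients at the perturbed weights drive the real update
    loss_fn(model(x), y).backward()
    optimizer.second_step(zero_grad=True)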
