|
1 | 1 | from typing import List |
2 | 2 |
|
3 | 3 | import pytest |
| 4 | +import torch |
4 | 5 | from torch import nn |
| 6 | +from torch.nn import functional as F |
5 | 7 |
|
6 | 8 | from pytorch_optimizer import SAM, AdamP, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer |
7 | 9 | from tests.utils import Example |
@@ -211,10 +213,34 @@ def test_ranger21_warm_methods(): |
211 | 213 | assert Ranger21.build_warm_down_iterations(1000) == 280 |
212 | 214 |
|
213 | 215 |
|
214 | | -def test_ranger21_variance_normalized(): |
def test_ranger21_size_of_parameter():
    """Stepping Ranger21 when every parameter is frozen must raise ValueError."""
    frozen_model: nn.Module = nn.Linear(1, 1, bias=False)
    frozen_model.requires_grad_(False)

    optimizer = Ranger21(frozen_model.parameters(), 100)
    with pytest.raises(ValueError):
        optimizer.step()
| 222 | + |
| 223 | + |
def test_ranger21_variance_normalized():
    """A degenerate beta2 (1e-9) must make Ranger21's variance normalization fail.

    With beta2 ~ 0 the second-moment estimate collapses, so ``step()`` is
    expected to raise ``RuntimeError`` from the variance-normalization path.
    """
    model: nn.Module = Example()
    optimizer = Ranger21(model.parameters(), num_iterations=100, betas=(0.9, 1e-9))

    y_pred = model(torch.ones((1, 1)))
    loss = F.binary_cross_entropy_with_logits(y_pred, torch.zeros(1, 1))
    # Without a backward pass no parameter has a gradient, and step() would
    # fail for the wrong reason (no grads at all) instead of exercising the
    # variance-normalization RuntimeError this test is about.
    loss.backward()

    with pytest.raises(RuntimeError):
        optimizer.step()
| 233 | + |
| 234 | + |
def test_ranger21_closure():
    """Ranger21.step must accept and invoke a loss-computing closure."""
    model: nn.Module = Example()
    optimizer = Ranger21(model.parameters(), num_iterations=100, betas=(0.9, 1e-9))

    loss_fn = nn.BCEWithLogitsLoss()

    def closure():
        # BCEWithLogitsLoss signature is (input, target): predictions (logits)
        # come first, the target second. The target must lie in [0, 1] — the
        # model's unbounded logits cannot be the target.
        loss = loss_fn(model(torch.ones((1, 1))), torch.ones((1, 1)))
        loss.backward()
        return loss

    optimizer.step(closure)
0 commit comments