
Commit 7410f83

update: test_ranger21_closure
1 parent b1433d0

1 file changed: +28 −2 lines

tests/test_optimizer_parameters.py

Lines changed: 28 additions & 2 deletions
@@ -1,7 +1,9 @@
 from typing import List

 import pytest
+import torch
 from torch import nn
+from torch.nn import functional as F

 from pytorch_optimizer import SAM, AdamP, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer
 from tests.utils import Example
@@ -211,10 +213,34 @@ def test_ranger21_warm_methods():
     assert Ranger21.build_warm_down_iterations(1000) == 280


-def test_ranger21_variance_normalized():
+def test_ranger21_size_of_parameter():
     model: nn.Module = nn.Linear(1, 1, bias=False)
     model.requires_grad_(False)

-    optimizer = Ranger21(model.parameters(), 100)
     with pytest.raises(ValueError):
+        Ranger21(model.parameters(), 100).step()
+
+
+def test_ranger21_variance_normalized():
+    model: nn.Module = Example()
+    optimizer = Ranger21(model.parameters(), num_iterations=100, betas=(0.9, 1e-9))
+
+    y_pred = model(torch.ones((1, 1)))
+    loss = F.binary_cross_entropy_with_logits(y_pred, torch.zeros(1, 1))
+
+    with pytest.raises(RuntimeError):
         optimizer.step()
+
+
+def test_ranger21_closure():
+    model: nn.Module = Example()
+    optimizer = Ranger21(model.parameters(), num_iterations=100, betas=(0.9, 1e-9))
+
+    loss_fn = nn.BCEWithLogitsLoss()
+
+    def closure():
+        loss = loss_fn(torch.ones((1, 1)), model(torch.ones((1, 1))))
+        loss.backward()
+        return loss
+
+    optimizer.step(closure)
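Note: the closure argument exercised by the new test follows the standard torch.optim protocol, in which step() may invoke the closure itself to re-evaluate the model and obtain the loss. Below is a minimal runnable sketch of that pattern, reusing the Ranger21 signature shown in this diff; the Sequential model and tensors are hypothetical stand-ins, not the repository's Example helper.

import torch
from torch import nn

from pytorch_optimizer import Ranger21

model = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid(), nn.Linear(1, 1))
optimizer = Ranger21(model.parameters(), num_iterations=100)

loss_fn = nn.BCEWithLogitsLoss()
x, y = torch.ones((1, 1)), torch.zeros((1, 1))


def closure():
    # The optimizer may call this to re-evaluate the loss; clearing stale
    # gradients before backward keeps repeated calls well-behaved.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss


optimizer.step(closure)

The committed test skips zero_grad() inside its closure; for a single step on fresh parameters that is harmless, since no stale gradients exist yet.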

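The Example model imported from tests.utils is not part of this diff. Going only by the shapes the tests feed it (a (1, 1) input and a (1, 1) logits target), a minimal stand-in could look like the sketch below; this is an assumption for illustration, not the repository's actual definition.

import torch
from torch import nn


class Example(nn.Module):
    # Hypothetical stand-in: any module mapping a (1, 1) input to a
    # (1, 1) logit with trainable parameters satisfies these tests.
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)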