Skip to content

Commit d341714

Browse files
committed
ci(test): test_sam_no_gradient
1 parent 8721654 commit d341714

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/test_optimizers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,18 @@ def test_no_closure():
376376

377377

378378
def test_sam_no_gradient():
379-
param = torch.randn(1, 1).requires_grad_(False)
380-
param.grad = torch.randn(1, 1)
379+
(x_data, y_data), model, loss_fn = build_environment()
380+
model.fc1.require_grads = False
381381

382-
optimizer = SAM([param], AdamP)
382+
optimizer = SAM(model.parameters(), AdamP)
383383
optimizer.zero_grad()
384-
optimizer.step(closure=dummy_closure)
384+
385+
loss = loss_fn(y_data, model(x_data))
386+
loss.backward()
387+
optimizer.first_step(zero_grad=True)
388+
389+
loss_fn(y_data, model(x_data)).backward()
390+
optimizer.second_step(zero_grad=True)
385391

386392

387393
@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)

0 commit comments

Comments
 (0)