Skip to content

Commit bda6532

Browse files
committed
fix: test_soap_parameters
1 parent 3b0052f commit bda6532

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/test_optimizers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,10 @@ def test_soap_parameters(params):
476476
for _ in range(2):
477477
optimizer.zero_grad()
478478

479-
model[0].weight.grad = None
480-
model[1].weight.grad = torch.randn((1, 8))
479+
model[0].weight.grad = torch.zeros((8, 2))
480+
model[0].bias.grad = torch.zeros((8,))
481+
model[1].weight.grad = torch.zeros((1, 8))
482+
model[1].bias.grad = torch.zeros((1,))
481483

482484
optimizer.step()
483485

0 commit comments

Comments
 (0)