Skip to content

Commit a2067a5

Browse files
committed
refactor: grad
1 parent 534c200 commit a2067a5

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tests/test_sparse_gradient.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
@pytest.mark.parametrize('no_sparse_optimizer', NO_SPARSE_OPTIMIZERS)
3030
def test_sparse_not_supported(no_sparse_optimizer):
3131
param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
32-
grad = torch.randn(1, 1).to_sparse(1)
33-
param.grad = grad
32+
param.grad = torch.randn(1, 1).to_sparse(1)
3433

3534
with pytest.raises(RuntimeError):
3635
optimizer = load_optimizers(optimizer=no_sparse_optimizer)([param])
@@ -41,8 +40,7 @@ def test_sparse_not_supported(no_sparse_optimizer):
4140
@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
4241
def test_sparse_supported(sparse_optimizer):
4342
param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
44-
grad = torch.randn(1, 1).to_sparse(1)
45-
param.grad = grad
43+
param.grad = torch.randn(1, 1).to_sparse(1)
4644

4745
optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0)
4846
optimizer.zero_grad()

0 commit comments

Comments
 (0)