Skip to content

Commit 04b22d5

Browse files
committed
update: test_sparse_supported
1 parent b7194a3 commit 04b22d5

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

tests/test_gradients.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,39 +43,59 @@ def test_sparse_not_supported(no_sparse_optimizer):
4343
optimizer.step(lambda: 0.1)
4444

4545

46+
@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
47+
def test_sparse(sparse_optimizer):
48+
opt = load_optimizer(optimizer=sparse_optimizer)
49+
50+
weight, weight_sparse = simple_sparse_parameter()
51+
52+
opt_dense = opt([weight], lr=1e-3, momentum=0.0)
53+
opt_sparse = opt([weight_sparse], lr=1e-3, momentum=0.0)
54+
55+
opt_dense.step()
56+
opt_sparse.step()
57+
assert torch.allclose(weight, weight_sparse)
58+
59+
weight.grad = torch.rand_like(weight)
60+
weight.grad[1] = 0.0
61+
weight_sparse.grad = weight.grad.to_sparse()
62+
63+
opt_dense.step()
64+
opt_sparse.step()
65+
assert torch.allclose(weight, weight_sparse)
66+
67+
weight.grad = torch.rand_like(weight)
68+
weight.grad[0] = 0.0
69+
weight_sparse.grad = weight.grad.to_sparse()
70+
71+
opt_dense.step()
72+
opt_sparse.step()
73+
assert torch.allclose(weight, weight_sparse)
74+
75+
4676
@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
4777
def test_sparse_supported(sparse_optimizer):
4878
opt = load_optimizer(optimizer=sparse_optimizer)
4979

50-
optimizer = opt([simple_sparse_parameter()], momentum=0.0)
80+
optimizer = opt([simple_sparse_parameter()[1]], momentum=0.0)
5181
optimizer.zero_grad()
5282
optimizer.step()
5383

54-
optimizer = opt([simple_sparse_parameter()], momentum=0.0)
55-
with pytest.raises(RuntimeError):
56-
optimizer.step()
57-
58-
optimizer = opt([simple_sparse_parameter()], momentum=0.0, eps=0.0)
59-
optimizer.reset()
60-
with pytest.raises(RuntimeError):
61-
optimizer.step()
84+
optimizer = opt([simple_sparse_parameter()[1]], momentum=0.0, eps=0.0)
85+
optimizer.step()
6286

6387
if sparse_optimizer == 'madgrad':
64-
optimizer = opt([simple_sparse_parameter()], momentum=0.0, weight_decay=1e-3, decouple_decay=False)
65-
optimizer.reset()
66-
88+
optimizer = opt([simple_sparse_parameter()[1]], momentum=0.0, weight_decay=1e-3, decouple_decay=False)
6789
with pytest.raises(NoSparseGradientError):
6890
optimizer.step()
6991

70-
optimizer = opt([simple_sparse_parameter()], momentum=0.9, weight_decay=1e-3)
92+
optimizer = opt([simple_sparse_parameter()[1]], momentum=0.9, weight_decay=1e-3)
7193
optimizer.reset()
72-
7394
if sparse_optimizer == 'madgrad':
7495
with pytest.raises(NoSparseGradientError):
7596
optimizer.step()
7697
else:
77-
with pytest.raises(RuntimeError):
78-
optimizer.step()
98+
optimizer.step()
7999

80100

81101
@pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES)

0 commit comments

Comments
 (0)