Skip to content

Commit afa33ed

Browse files
committed
update: test_sparse_supported
1 parent 30f648f commit afa33ed

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/test_sparse_gradient.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
from pytorch_optimizer import load_optimizers
77

8+
SPARSE_OPTIMIZERS: List[str] = [
9+
'madgrad',
10+
]
11+
812
NO_SPARSE_OPTIMIZERS: List[str] = [
913
'adamp',
1014
'sgdp',
@@ -32,3 +36,14 @@ def test_sparse_not_supported(no_sparse_optimizer):
3236

3337
with pytest.raises(RuntimeError):
3438
optimizer.step()
39+
40+
41+
@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
42+
def test_sparse_supported(sparse_optimizer):
43+
param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
44+
grad = torch.randn(1, 1).to_sparse(1)
45+
param.grad = grad
46+
47+
optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0)
48+
optimizer.zero_grad()
49+
optimizer.step()

0 commit comments

Comments
 (0)