1 parent 30f648f commit afa33ed
tests/test_sparse_gradient.py
@@ -5,6 +5,10 @@
 
 from pytorch_optimizer import load_optimizers
 
+SPARSE_OPTIMIZERS: List[str] = [
+    'madgrad',
+]
+
 NO_SPARSE_OPTIMIZERS: List[str] = [
     'adamp',
     'sgdp',
@@ -32,3 +36,14 @@ def test_sparse_not_supported(no_sparse_optimizer):
 
     with pytest.raises(RuntimeError):
         optimizer.step()
+
+
+@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
+def test_sparse_supported(sparse_optimizer):
+    param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
+    grad = torch.randn(1, 1).to_sparse(1)
+    param.grad = grad
+
+    optimizer = load_optimizers(optimizer=sparse_optimizer)([param], momentum=0.0)
+    optimizer.zero_grad()
+    optimizer.step()
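
For context, here is a minimal standalone sketch of the behavior these tests exercise. The DenseOnlySGD class below is hypothetical (it is not part of pytorch_optimizer); it only illustrates the common convention that test_sparse_not_supported relies on: an optimizer without sparse support raises RuntimeError when step() encounters a sparse gradient, while a sparse-aware optimizer such as madgrad is expected to step without error.

import pytest
import torch


class DenseOnlySGD(torch.optim.Optimizer):
    """Hypothetical optimizer that, like those in NO_SPARSE_OPTIMIZERS,
    rejects sparse gradients instead of silently mishandling them."""

    def __init__(self, params, lr: float = 0.1):
        super().__init__(params, {'lr': lr})

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    # The guard the test suite checks for with pytest.raises(RuntimeError).
                    raise RuntimeError('DenseOnlySGD does not support sparse gradients')
                p.add_(p.grad, alpha=-group['lr'])


# Build a sparse leaf parameter with a sparse gradient, as in the tests above.
param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
param.grad = torch.randn(1, 1).to_sparse(1)

with pytest.raises(RuntimeError):
    DenseOnlySGD([param]).step()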