1 parent 84be8c4 commit 9737e2b
tests/test_sparse_gradient.py
@@ -14,6 +14,7 @@
     'sgdp',
     'madgrad',
     'ranger',
+    'ranger21',
     'radam',
     'adabound',
     'adahessian',
@@ -31,9 +32,15 @@ def test_sparse_not_supported(no_sparse_optimizer):
     param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
     param.grad = torch.randn(1, 1).to_sparse(1)

+    optimizer = load_optimizers(optimizer=no_sparse_optimizer)
+    if no_sparse_optimizer == 'ranger21':
+        optimizer = optimizer([param], num_iterations=1)
+    else:
+        optimizer = optimizer([param])
+
+    optimizer.zero_grad()

     with pytest.raises(RuntimeError):
-        optimizer = load_optimizers(optimizer=no_sparse_optimizer)([param])
-        optimizer.zero_grad()
         optimizer.step()
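For readability, here is a sketch of how the updated test body reads once this diff is applied. It is reconstructed from the hunks above; the imports, the `no_sparse_optimizer` fixture/parametrization, and the import path of `load_optimizers` are assumptions about the surrounding test file, not part of this commit.

import pytest
import torch

# Assumed import path; the real test file defines or imports load_optimizers elsewhere.
from tests.utils import load_optimizers


def test_sparse_not_supported(no_sparse_optimizer):
    # A sparse parameter carrying a sparse gradient.
    param = torch.randn(1, 1).to_sparse(1).requires_grad_(True)
    param.grad = torch.randn(1, 1).to_sparse(1)

    # Ranger21 needs the total iteration count at construction time,
    # so it gets its own branch; the other optimizers only take the params.
    optimizer = load_optimizers(optimizer=no_sparse_optimizer)
    if no_sparse_optimizer == 'ranger21':
        optimizer = optimizer([param], num_iterations=1)
    else:
        optimizer = optimizer([param])

    optimizer.zero_grad()

    # Stepping with sparse gradients is unsupported, so the optimizer should raise.
    with pytest.raises(RuntimeError):
        optimizer.step()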