@@ -43,39 +43,59 @@ def test_sparse_not_supported(no_sparse_optimizer):
4343 optimizer .step (lambda : 0.1 )
4444
4545
@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
def test_sparse(sparse_optimizer):
    """Dense and sparse gradients must produce identical parameter updates."""
    opt = load_optimizer(optimizer=sparse_optimizer)

    weight, weight_sparse = simple_sparse_parameter()

    opt_dense = opt([weight], lr=1e-3, momentum=0.0)
    opt_sparse = opt([weight_sparse], lr=1e-3, momentum=0.0)

    # First step uses whatever gradients simple_sparse_parameter() prepared.
    opt_dense.step()
    opt_sparse.step()
    assert torch.allclose(weight, weight_sparse)

    # Repeat with fresh random gradients, zeroing a different row each time so
    # the sparse tensor genuinely carries missing entries (row 1, then row 0).
    for zeroed_row in (1, 0):
        weight.grad = torch.rand_like(weight)
        weight.grad[zeroed_row] = 0.0
        weight_sparse.grad = weight.grad.to_sparse()

        opt_dense.step()
        opt_sparse.step()
        assert torch.allclose(weight, weight_sparse)
@pytest.mark.parametrize('sparse_optimizer', SPARSE_OPTIMIZERS)
def test_sparse_supported(sparse_optimizer):
    """Sparse gradients step cleanly; unsupported option combinations raise."""
    opt = load_optimizer(optimizer=sparse_optimizer)

    def sparse_param():
        # simple_sparse_parameter() returns a (dense, sparse) pair; keep the
        # sparse one only — presumably the dense twin is irrelevant here.
        return simple_sparse_parameter()[1]

    # Plain step (gradients cleared first) must succeed.
    optimizer = opt([sparse_param()], momentum=0.0)
    optimizer.zero_grad()
    optimizer.step()

    # eps=0.0 must also be handled for sparse gradients.
    optimizer = opt([sparse_param()], momentum=0.0, eps=0.0)
    optimizer.step()

    if sparse_optimizer == 'madgrad':
        # Non-decoupled weight decay is rejected for sparse gradients.
        optimizer = opt([sparse_param()], momentum=0.0, weight_decay=1e-3, decouple_decay=False)
        with pytest.raises(NoSparseGradientError):
            optimizer.step()

    # Momentum with weight decay: madgrad rejects it, others step fine.
    optimizer = opt([sparse_param()], momentum=0.9, weight_decay=1e-3)
    optimizer.reset()
    if sparse_optimizer == 'madgrad':
        with pytest.raises(NoSparseGradientError):
            optimizer.step()
    else:
        optimizer.step()
80100
81101@pytest .mark .parametrize ('optimizer_name' , VALID_OPTIMIZER_NAMES )
0 commit comments