Skip to content

Commit c8d889e

Browse files
committed
update: test_d_adapt_reset
1 parent e19983d commit c8d889e

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/test_optimizers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
make_dataset,
1616
names,
1717
simple_parameter,
18+
simple_sparse_parameter,
1819
tensor_to_numpy,
1920
)
2021

@@ -272,9 +273,12 @@ def test_reset(optimizer_config):
272273
optimizer.reset()
273274

274275

276+
@pytest.mark.parametrize('sparse_gradient', [False, True])
275277
@pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD'])
276-
def test_d_adapt_reset(optimizer_name):
277-
optimizer = load_optimizer(optimizer_name)([simple_parameter()])
278+
def test_d_adapt_reset(sparse_gradient, optimizer_name):
279+
param = simple_sparse_parameter() if sparse_gradient else simple_parameter()
280+
281+
optimizer = load_optimizer(optimizer_name)([param])
278282
assert optimizer.__str__ == optimizer_name
279283
optimizer.reset()
280284

0 commit comments

Comments
 (0)