
Commit e776f61

update: test_d_adapt_reset

1 parent c8d889e

2 files changed: +7 -4 lines changed

tests/test_optimizers.py

Lines changed: 5 additions & 2 deletions
@@ -273,10 +273,13 @@ def test_reset(optimizer_config):
     optimizer.reset()
 
 
+@pytest.mark.parametrize('require_gradient', [False, True])
 @pytest.mark.parametrize('sparse_gradient', [False, True])
 @pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD'])
-def test_d_adapt_reset(sparse_gradient, optimizer_name):
-    param = simple_sparse_parameter() if sparse_gradient else simple_parameter()
+def test_d_adapt_reset(require_gradient, sparse_gradient, optimizer_name):
+    param = simple_sparse_parameter(require_gradient) if sparse_gradient else simple_parameter(require_gradient)
+    if not require_gradient:
+        param.grad = None
 
     optimizer = load_optimizer(optimizer_name)([param])
     assert optimizer.__str__ == optimizer_name
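For context, a minimal standalone sketch of what this parametrization now covers: both gradient-requiring and gradient-free parameters, with and without an attached .grad. It assumes the pytorch_optimizer package exports DAdaptAdam (one of the three parametrized names) and that, as in test_reset above, the optimizer exposes a reset() method; the repo's simple_parameter fixture is inlined for self-containment.

import torch

from pytorch_optimizer import DAdaptAdam  # stand-in for all three D-Adaptation optimizers

for require_gradient in (False, True):
    # mirror the updated tests/utils.py fixture: a deterministic zero parameter
    param = torch.zeros(1, 1).requires_grad_(require_gradient)
    param.grad = torch.zeros(1, 1)

    if not require_gradient:
        # as in the test body: exercise the optimizer with no gradient attached
        param.grad = None

    optimizer = DAdaptAdam([param])
    optimizer.reset()  # presumably the call under test, as in test_reset above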

tests/utils.py

Lines changed: 2 additions & 2 deletions
@@ -45,8 +45,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 def simple_parameter(require_grad: bool = True) -> torch.Tensor:
-    param = torch.randn(1, 1).requires_grad_(require_grad)
-    param.grad = torch.randn(1, 1)
+    param = torch.zeros(1, 1).requires_grad_(require_grad)
+    param.grad = torch.zeros(1, 1)
     return param
 
 
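Note that the helper still assigns param.grad unconditionally, even when require_grad is False; that is presumably why test_d_adapt_reset above clears the gradient explicitly rather than relying on the fixture. Switching from torch.randn to torch.zeros also makes the fixture deterministic across runs. A short usage sketch of the updated helper, assuming only plain PyTorch:

import torch

def simple_parameter(require_grad: bool = True) -> torch.Tensor:
    # deterministic fixture: zeros instead of the previous randn
    param = torch.zeros(1, 1).requires_grad_(require_grad)
    # the gradient is attached regardless of require_grad
    param.grad = torch.zeros(1, 1)
    return param

p = simple_parameter(require_grad=False)
assert not p.requires_grad
assert p.grad is not None  # the fixture still attaches a gradient...
p.grad = None              # ...so the test nulls it out itself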
0 commit comments