Skip to content

Commit bdc48b1

Browse files
committed
update: simple_parameter
1 parent b797a53 commit bdc48b1

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tests/test_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def test_reset(optimizer_config):
273273

274274
@pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD'])
275275
def test_d_adapt_reset(optimizer_name):
276-
optimizer = load_optimizer(optimizer_name)(MultiHeadLogisticRegression().parameters())
276+
optimizer = load_optimizer(optimizer_name)([simple_parameter()])
277277
assert optimizer.__str__ == optimizer_name
278278
optimizer.reset()
279279

tests/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4545

4646

4747
def simple_parameter(require_grad: bool = True) -> torch.Tensor:
48-
return torch.zeros(1, 1).requires_grad_(require_grad)
48+
param = torch.randn(1, 1).requires_grad_(require_grad)
49+
param.grad = torch.randn(1, 1)
50+
return param
4951

5052

5153
def simple_sparse_parameter(require_grad: bool = True) -> torch.Tensor:

0 commit comments

Comments
 (0)