Skip to content

Commit 58b4226

Browse files
committed
update: simple_sparse_parameter
1 parent 46438e2 commit 58b4226

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tests/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ def simple_sparse_parameter(require_grad: bool = True) -> Tuple[torch.Tensor, to
5454
weight = torch.randn(5, 1).requires_grad_(require_grad)
5555
weight_sparse = weight.detach().requires_grad_(require_grad)
5656

57-
weight.grad = torch.rand_like(weight)
58-
weight.grad[0] = 0.0
59-
weight_sparse.grad = weight.grad.to_sparse()
57+
if require_grad:
58+
weight.grad = torch.rand_like(weight)
59+
weight.grad[0] = 0.0
60+
weight_sparse.grad = weight.grad.to_sparse()
6061

6162
return weight, weight_sparse
6263

0 commit comments

Comments
 (0)