Skip to content

Commit b7194a3

Browse files
committed
update: simple_sparse_parameter
1 parent 37524ed commit b7194a3

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tests/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ def simple_parameter(require_grad: bool = True) -> torch.Tensor:
5353
def simple_sparse_parameter(require_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
5454
weight = torch.randn(5, 1).requires_grad_(require_grad)
5555
weight_sparse = weight.detach().requires_grad_(require_grad)
56+
57+
weight.grad = torch.rand_like(weight)
58+
weight.grad[0] = 0.0
59+
weight_sparse.grad = weight.grad.to_sparse()
60+
5661
return weight, weight_sparse
5762

5863

0 commit comments

Comments
 (0)