We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 37524ed commit b7194a3Copy full SHA for b7194a3
tests/utils.py
@@ -53,6 +53,11 @@ def simple_parameter(require_grad: bool = True) -> torch.Tensor:
53
def simple_sparse_parameter(require_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
54
weight = torch.randn(5, 1).requires_grad_(require_grad)
55
weight_sparse = weight.detach().requires_grad_(require_grad)
56
+
57
+ weight.grad = torch.rand_like(weight)
58
+ weight.grad[0] = 0.0
59
+ weight_sparse.grad = weight.grad.to_sparse()
60
61
return weight, weight_sparse
62
63
0 commit comments