Skip to content

Commit ffeb153

Browse files
committed
update: simple_sparse_parameter
1 parent fdd6f80 commit ffeb153

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def simple_parameter(require_grad: bool = True) -> torch.Tensor:
5050
return param
5151

5252

53-
def simple_sparse_parameter(require_grad: bool = True) -> torch.Tensor:
54-
param = torch.randn(2, 2).to_sparse(1).requires_grad_(require_grad)
55-
param.grad = torch.randn(2, 2).to_sparse(1)
56-
return param
53+
def simple_sparse_parameter(require_grad: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
54+
weight = torch.randn(5, 1).requires_grad_(require_grad)
55+
weight_sparse = weight.detach().requires_grad_(require_grad)
56+
return weight, weight_sparse
5757

5858

5959
def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:

0 commit comments

Comments
 (0)