Skip to content

Commit efa8019

Browse files
committed
improve
1 parent 7f5c097 commit efa8019

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/torchjd/autogram/diagonal_sparse_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def to_dense(self) -> Tensor:
125125
identity_matrices[j] = torch.eye(self._v_shape[i], device=device, dtype=dtype)
126126
einsum_args += [identity_matrices[j], [first_indices[j], i]]
127127

128+
# Need to be careful about nans, we would want to get identity times nan.
128129
output = torch.einsum(*einsum_args, output_indices)
129130
return output
130131

tests/unit/autogram/test_diagonal_sparse_tensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,11 @@ def test_pointwise(func):
5656
res = func(b)
5757
assert isinstance(res, DiagonalSparseTensor)
5858

59-
# need to be careful about nans
60-
assert_close(res, func(c))
59+
assert_close(res, func(c), equal_nan=True)
6160

6261

6362
@mark.parametrize("func", [torch.mean, torch.sum])
64-
def test_mean(func):
63+
def test_unary(func):
6564
dim = 10
6665
a = randn_([dim])
6766
b = DiagonalSparseTensor(a, [0, 0])

0 commit comments

Comments
 (0)