Skip to content

Commit 447d714

Browse files
committed
Give implementation for pointwise
1 parent f476b29 commit 447d714

File tree

2 files changed

+104
-7
lines changed

2 files changed

+104
-7
lines changed

src/torchjd/autogram/diagonal_sparse_tensor.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,70 @@
22

33
import torch
44
from torch import Tensor
5+
from torch.ops import aten
56
from torch.utils._pytree import tree_map
67

8+
# pointwise functions applied to one Tensor with `0.0 → 0`
9+
_pointwise_functions = {
10+
aten.abs.default,
11+
aten.abs_.default,
12+
aten.absolute.default,
13+
aten.absolute_.default,
14+
aten.neg.default,
15+
aten.neg_.default,
16+
aten.negative.default,
17+
aten.negative_.default,
18+
aten.sign.default,
19+
aten.sign_.default,
20+
aten.sgn.default,
21+
aten.sgn_.default,
22+
aten.square.default,
23+
aten.square_.default,
24+
aten.fix.default,
25+
aten.fix_.default,
26+
aten.floor.default,
27+
aten.floor_.default,
28+
aten.ceil.default,
29+
aten.ceil_.default,
30+
aten.trunc.default,
31+
aten.trunc_.default,
32+
aten.round.default,
33+
aten.round_.default,
34+
aten.positive.default,
35+
aten.expm1.default,
36+
aten.expm1_.default,
37+
aten.log1p.default,
38+
aten.log1p_.default,
39+
aten.sqrt.default,
40+
aten.sqrt_.default,
41+
aten.sin.default,
42+
aten.sin_.default,
43+
aten.tan.default,
44+
aten.tan_.default,
45+
aten.sinh.default,
46+
aten.sinh_.default,
47+
aten.tanh.default,
48+
aten.tanh_.default,
49+
aten.asin.default,
50+
aten.asin_.default,
51+
aten.atan.default,
52+
aten.atan_.default,
53+
aten.asinh.default,
54+
aten.asinh_.default,
55+
aten.atanh.default,
56+
aten.atanh_.default,
57+
aten.erf.default,
58+
aten.erf_.default,
59+
aten.erfinv.default,
60+
aten.erfinv_.default,
61+
aten.relu.default,
62+
aten.relu_.default,
63+
aten.hardtanh.default,
64+
aten.hardtanh_.default,
65+
aten.leaky_relu.default,
66+
aten.leaky_relu_.default,
67+
}
68+
769

870
class DiagonalSparseTensor(torch.Tensor):
971

@@ -50,10 +112,19 @@ def to_dense(self) -> Tensor:
50112
return output
51113

52114
@classmethod
53-
def __torch_dispatch__(
54-
cls, func: {__name__}, types: Any, args: tuple[()] | Any = (), kwargs: Any = None
55-
):
56-
kwargs = kwargs if kwargs else {}
115+
def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwargs: Any = None):
116+
kwargs = {} if kwargs is None else kwargs
117+
118+
# If `func` is a pointwise operator that applies to a single Tensor and such that func(0)=0
119+
# Then we can apply the transformation to self._data and wrap the result.
120+
if func in _pointwise_functions:
121+
assert (
122+
isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0
123+
)
124+
sparse_tensor = args[0]
125+
assert isinstance(sparse_tensor, DiagonalSparseTensor)
126+
new_data = func(sparse_tensor._data)
127+
return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p)
57128

58129
# TODO: Handle batched operations (apply to self._data and wrap)
59130
# TODO: Handle all operations that can be represented with an einsum by translating them

tests/unit/autogram/test_diagonal_sparse_tensor.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
22
from pytest import mark
33
from torch.testing import assert_close
4+
from utils.tensors import randn_, zeros_
45

5-
from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor
6+
from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor, _pointwise_functions
67

78

89
@mark.parametrize(
@@ -18,17 +19,42 @@
1819
],
1920
)
2021
def test_diagonal_spase_tensor_scalar(shape: list[int]):
21-
a = torch.randn(shape)
22+
a = randn_(shape)
2223
b = DiagonalSparseTensor(a, list(range(len(shape))))
2324

2425
assert_close(a, b)
2526

2627

2728
@mark.parametrize("dim", [1, 2, 3, 4, 5, 10])
2829
def test_diag_equivalence(dim: int):
29-
a = torch.randn([dim])
30+
a = randn_([dim])
3031
b = DiagonalSparseTensor(a, [0, 0])
3132

3233
diag_a = torch.diag(a)
3334

3435
assert_close(b, diag_a)
36+
37+
38+
def test_three_virtual_single_physical():
39+
dim = 10
40+
a = randn_([dim])
41+
b = DiagonalSparseTensor(a, [0, 0, 0])
42+
43+
expected = zeros_([dim, dim, dim])
44+
for i in range(dim):
45+
expected[i, i, i] = a[i]
46+
47+
assert_close(b, expected)
48+
49+
50+
@mark.parametrize("func", _pointwise_functions)
51+
def test_pointwise(func):
52+
dim = 100
53+
a = randn_([dim])
54+
b = DiagonalSparseTensor(a, [0, 0])
55+
c = b.to_dense()
56+
d = func(b)
57+
assert isinstance(d, DiagonalSparseTensor)
58+
59+
# need to be careful about nans
60+
assert_close(d, func(c))

0 commit comments

Comments
 (0)