Skip to content

Commit 85556a8

Browse files
committed
Add decorator to handle other functions. Add two examples of such functions (mean and sum).
1 parent 447d714 commit 85556a8

File tree

2 files changed

+50
-11
lines changed

2 files changed

+50
-11
lines changed

src/torchjd/autogram/diagonal_sparse_tensor.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.utils._pytree import tree_map
77

88
# pointwise functions applied to one Tensor with `0.0 → 0`
9-
_pointwise_functions = {
9+
_POINTWISE_FUNCTIONS = {
1010
aten.abs.default,
1111
aten.abs_.default,
1212
aten.absolute.default,
@@ -65,6 +65,19 @@
6565
aten.leaky_relu.default,
6666
aten.leaky_relu_.default,
6767
}
68+
_HANDLED_FUNCTIONS = dict()
69+
import functools
70+
71+
72+
def implements(torch_function):
73+
"""Register a torch function override for ScalarTensor"""
74+
75+
def decorator(func):
76+
functools.update_wrapper(func, torch_function)
77+
_HANDLED_FUNCTIONS[torch_function] = func
78+
return func
79+
80+
return decorator
6881

6982

7083
class DiagonalSparseTensor(torch.Tensor):
@@ -85,6 +98,10 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
8598
# (which is bad!)
8699
assert not data.requires_grad or not torch.is_grad_enabled()
87100

101+
# TODO: assert a minimal data, all of its dimensions must be used at least once
102+
# TODO: If no repeat in v_to_p, return a view of data (non sparse tensor). If this cannot be
103+
# done in __new__, create a helper function for that, and use this one everywhere.
104+
88105
shape = [data.shape[i] for i in v_to_p]
89106
return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)
90107

@@ -117,7 +134,7 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar
117134

118135
# If `func` is a pointwise operator that applies to a single Tensor and such that func(0)=0
119136
# Then we can apply the transformation to self._data and wrap the result.
120-
if func in _pointwise_functions:
137+
if func in _POINTWISE_FUNCTIONS:
121138
assert (
122139
isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0
123140
)
@@ -126,9 +143,8 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar
126143
new_data = func(sparse_tensor._data)
127144
return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p)
128145

129-
# TODO: Handle batched operations (apply to self._data and wrap)
130-
# TODO: Handle all operations that can be represented with an einsum by translating them
131-
# to operations on self._data and wrapping accordingly.
146+
if func in _HANDLED_FUNCTIONS:
147+
return _HANDLED_FUNCTIONS[func](*args, **kwargs)
132148

133149
# --- Fallback: Fold to Dense Tensor ---
134150
def unwrap_to_dense(t: Tensor):
@@ -145,3 +161,15 @@ def __repr__(self):
145161
f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape="
146162
f"{self._v_shape})"
147163
)
164+
165+
166+
@implements(aten.mean.default)
167+
def mean_default(t: Tensor) -> Tensor:
168+
assert isinstance(t, DiagonalSparseTensor)
169+
return aten.sum.default(t._data) / t.numel()
170+
171+
172+
@implements(aten.sum.default)
173+
def sum_default(t: Tensor) -> Tensor:
174+
assert isinstance(t, DiagonalSparseTensor)
175+
return aten.sum.default(t._data)

tests/unit/autogram/test_diagonal_sparse_tensor.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch.testing import assert_close
44
from utils.tensors import randn_, zeros_
55

6-
from torchjd.autogram.diagonal_sparse_tensor import DiagonalSparseTensor, _pointwise_functions
6+
from torchjd.autogram.diagonal_sparse_tensor import _POINTWISE_FUNCTIONS, DiagonalSparseTensor
77

88

99
@mark.parametrize(
@@ -47,14 +47,25 @@ def test_three_virtual_single_physical():
4747
assert_close(b, expected)
4848

4949

50-
@mark.parametrize("func", _pointwise_functions)
50+
@mark.parametrize("func", _POINTWISE_FUNCTIONS)
5151
def test_pointwise(func):
52-
dim = 100
52+
dim = 10
5353
a = randn_([dim])
5454
b = DiagonalSparseTensor(a, [0, 0])
5555
c = b.to_dense()
56-
d = func(b)
57-
assert isinstance(d, DiagonalSparseTensor)
56+
res = func(b)
57+
assert isinstance(res, DiagonalSparseTensor)
5858

5959
# need to be careful about nans
60-
assert_close(d, func(c))
60+
assert_close(res, func(c))
61+
62+
63+
@mark.parametrize("func", [torch.mean, torch.sum])
64+
def test_mean(func):
65+
dim = 10
66+
a = randn_([dim])
67+
b = DiagonalSparseTensor(a, [0, 0])
68+
c = b.to_dense()
69+
70+
mean = func(b)
71+
assert_close(mean, func(c))

0 commit comments

Comments
 (0)