Skip to content

Commit f476b29

Browse files
committed
revert removing __init__
1 parent a0b7ffc commit f476b29

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/torchjd/autogram/diagonal_sparse_tensor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
import torch
24
from torch import Tensor
35
from torch.utils._pytree import tree_map
@@ -22,10 +24,12 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
2224
assert not data.requires_grad or not torch.is_grad_enabled()
2325

2426
shape = [data.shape[i] for i in v_to_p]
25-
result = Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)
26-
result._data = data # type: ignore
27-
result._v_to_p = v_to_p # type: ignore
28-
result._v_shape = shape # type: ignore
27+
return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)
28+
29+
def __init__(self, data: Tensor, v_to_p: list[int]):
30+
self._data = data
31+
self._v_to_p = v_to_p
32+
self._v_shape = [data.shape[i] for i in v_to_p]
2933

3034
def to_dense(self) -> Tensor:
3135
first_indices = dict[int, int]()
@@ -46,15 +50,17 @@ def to_dense(self) -> Tensor:
4650
return output
4751

4852
@classmethod
49-
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
53+
def __torch_dispatch__(
54+
cls, func: {__name__}, types: Any, args: tuple[()] | Any = (), kwargs: Any = None
55+
):
5056
kwargs = kwargs if kwargs else {}
5157

5258
# TODO: Handle batched operations (apply to self._data and wrap)
5359
# TODO: Handle all operations that can be represented with an einsum by translating them
5460
# to operations on self._data and wrapping accordingly.
5561

5662
# --- Fallback: Fold to Dense Tensor ---
57-
def unwrap_to_dense(t):
63+
def unwrap_to_dense(t: Tensor):
5864
if isinstance(t, cls):
5965
return t.to_dense()
6066
else:

0 commit comments

Comments
 (0)