1+ from typing import Any
2+
13import torch
24from torch import Tensor
35from torch .utils ._pytree import tree_map
@@ -22,10 +24,12 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
2224 assert not data .requires_grad or not torch .is_grad_enabled ()
2325
2426 shape = [data .shape [i ] for i in v_to_p ]
25- result = Tensor ._make_wrapper_subclass (cls , shape , dtype = data .dtype , device = data .device )
26- result ._data = data # type: ignore
27- result ._v_to_p = v_to_p # type: ignore
28- result ._v_shape = shape # type: ignore
27+ return Tensor ._make_wrapper_subclass (cls , shape , dtype = data .dtype , device = data .device )
28+
def __init__(self, data: Tensor, v_to_p: list[int]):
    """Record the backing tensor, the dim mapping, and the derived virtual shape.

    Args:
        data: Physical tensor whose dims are indexed by ``v_to_p``.
        v_to_p: For each virtual dim, the index of the physical dim in
            ``data`` that it maps to (so the virtual shape is
            ``[data.shape[i] for i in v_to_p]``).
    """
    self._v_to_p = v_to_p
    self._data = data
    self._v_shape = [data.shape[dim] for dim in v_to_p]
2933
3034 def to_dense (self ) -> Tensor :
3135 first_indices = dict [int , int ]()
@@ -46,15 +50,17 @@ def to_dense(self) -> Tensor:
4650 return output
4751
4852 @classmethod
49- def __torch_dispatch__ (cls , func , types , args = (), kwargs = None ):
53+ def __torch_dispatch__ (
54+ cls , func : {__name__ }, types : Any , args : tuple [()] | Any = (), kwargs : Any = None
55+ ):
5056 kwargs = kwargs if kwargs else {}
5157
5258 # TODO: Handle batched operations (apply to self._data and wrap)
5359 # TODO: Handle all operations that can be represented with an einsum by translating them
5460 # to operations on self._data and wrapping accordingly.
5561
5662 # --- Fallback: Fold to Dense Tensor ---
57- def unwrap_to_dense (t ):
63+ def unwrap_to_dense (t : Tensor ):
5864 if isinstance (t , cls ):
5965 return t .to_dense ()
6066 else :
0 commit comments