from torch.utils._pytree import tree_map

# Pointwise functions of a single Tensor that map 0.0 to 0
-_pointwise_functions = {
+_POINTWISE_FUNCTIONS = {
    aten.abs.default,
    aten.abs_.default,
    aten.absolute.default,
    # [... unchanged entries omitted ...]
    aten.leaky_relu.default,
    aten.leaky_relu_.default,
}
+_HANDLED_FUNCTIONS = dict()
+import functools
+
+
+def implements(torch_function):
+    """Register a torch function override for DiagonalSparseTensor."""
+
+    def decorator(func):
+        functools.update_wrapper(func, torch_function)
+        _HANDLED_FUNCTIONS[torch_function] = func
+        return func
+
+    return decorator


class DiagonalSparseTensor(torch.Tensor):
@@ -85,6 +98,10 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
        # (which is bad!)
        assert not data.requires_grad or not torch.is_grad_enabled()

+        # TODO: assert that data is minimal, i.e. every one of its dimensions is used at least once.
+        # TODO: if v_to_p has no repeated entries, return a view of data (a non-sparse tensor). If
+        #  that cannot be done in __new__, add a helper function for it and use it everywhere.
+
        shape = [data.shape[i] for i in v_to_p]
        return Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype, device=data.device)

@@ -117,7 +134,7 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar

        # If `func` is a pointwise operator that takes a single Tensor and satisfies func(0) = 0,
        # then we can apply it to self._data and wrap the result.
-        if func in _pointwise_functions:
+        if func in _POINTWISE_FUNCTIONS:
            assert (
                isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0
            )
@@ -126,9 +143,8 @@ def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwar
            new_data = func(sparse_tensor._data)
            return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p)

-        # TODO: Handle batched operations (apply to self._data and wrap)
-        # TODO: Handle all operations that can be represented with an einsum by translating them
-        #  to operations on self._data and wrapping accordingly.
+        if func in _HANDLED_FUNCTIONS:
+            return _HANDLED_FUNCTIONS[func](*args, **kwargs)

        # --- Fallback: Fold to Dense Tensor ---
        def unwrap_to_dense(t: Tensor):
@@ -145,3 +161,15 @@ def __repr__(self):
            f"DiagonalSparseTensor(data={self._data}, v_to_p_map={self._v_to_p}, shape="
            f"{self._v_shape})"
        )
+
+
+@implements(aten.mean.default)
+def mean_default(t: Tensor) -> Tensor:
+    assert isinstance(t, DiagonalSparseTensor)
+    return aten.sum.default(t._data) / t.numel()
+
+
+@implements(aten.sum.default)
+def sum_default(t: Tensor) -> Tensor:
+    assert isinstance(t, DiagonalSparseTensor)
+    return aten.sum.default(t._data)
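
For context, here is a small usage sketch of the two dispatch paths introduced above. It is not part of this commit: it assumes the parts of DiagonalSparseTensor not shown in the diff behave as the visible lines suggest (in particular, that the wrapped tensor is stored in `_data`), and the import path `diagonal_sparse_tensor` is a made-up placeholder.

# Illustrative sketch only; the module name below is hypothetical.
import torch
from diagonal_sparse_tensor import DiagonalSparseTensor

data = torch.tensor([1.0, -2.0, 3.0])

# v_to_p = [0, 0]: both virtual dims map to the single physical dim of `data`,
# i.e. a 3x3 matrix with `data` on the diagonal and zeros everywhere else.
d = DiagonalSparseTensor(data, [0, 0])  # d.shape is torch.Size([3, 3])

# Pointwise path: abs(0) == 0, so abs is applied to the stored data and
# rewrapped; the result stays diagonal-sparse.
a = torch.abs(d)

# Registered-handler path: aten.mean.default is looked up in _HANDLED_FUNCTIONS.
# Off-diagonal entries are zero, so mean = sum(data) / numel = (1 - 2 + 3) / 9.
m = d.mean()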