| 
2 | 2 | 
 
  | 
3 | 3 | import torch  | 
4 | 4 | from torch import Tensor  | 
 | 5 | +from torch.ops import aten  | 
5 | 6 | from torch.utils._pytree import tree_map  | 
6 | 7 | 
 
  | 
 | 8 | +# pointwise functions applied to one Tensor with `0.0 → 0`  | 
 | 9 | +_pointwise_functions = {  | 
 | 10 | +    aten.abs.default,  | 
 | 11 | +    aten.abs_.default,  | 
 | 12 | +    aten.absolute.default,  | 
 | 13 | +    aten.absolute_.default,  | 
 | 14 | +    aten.neg.default,  | 
 | 15 | +    aten.neg_.default,  | 
 | 16 | +    aten.negative.default,  | 
 | 17 | +    aten.negative_.default,  | 
 | 18 | +    aten.sign.default,  | 
 | 19 | +    aten.sign_.default,  | 
 | 20 | +    aten.sgn.default,  | 
 | 21 | +    aten.sgn_.default,  | 
 | 22 | +    aten.square.default,  | 
 | 23 | +    aten.square_.default,  | 
 | 24 | +    aten.fix.default,  | 
 | 25 | +    aten.fix_.default,  | 
 | 26 | +    aten.floor.default,  | 
 | 27 | +    aten.floor_.default,  | 
 | 28 | +    aten.ceil.default,  | 
 | 29 | +    aten.ceil_.default,  | 
 | 30 | +    aten.trunc.default,  | 
 | 31 | +    aten.trunc_.default,  | 
 | 32 | +    aten.round.default,  | 
 | 33 | +    aten.round_.default,  | 
 | 34 | +    aten.positive.default,  | 
 | 35 | +    aten.expm1.default,  | 
 | 36 | +    aten.expm1_.default,  | 
 | 37 | +    aten.log1p.default,  | 
 | 38 | +    aten.log1p_.default,  | 
 | 39 | +    aten.sqrt.default,  | 
 | 40 | +    aten.sqrt_.default,  | 
 | 41 | +    aten.sin.default,  | 
 | 42 | +    aten.sin_.default,  | 
 | 43 | +    aten.tan.default,  | 
 | 44 | +    aten.tan_.default,  | 
 | 45 | +    aten.sinh.default,  | 
 | 46 | +    aten.sinh_.default,  | 
 | 47 | +    aten.tanh.default,  | 
 | 48 | +    aten.tanh_.default,  | 
 | 49 | +    aten.asin.default,  | 
 | 50 | +    aten.asin_.default,  | 
 | 51 | +    aten.atan.default,  | 
 | 52 | +    aten.atan_.default,  | 
 | 53 | +    aten.asinh.default,  | 
 | 54 | +    aten.asinh_.default,  | 
 | 55 | +    aten.atanh.default,  | 
 | 56 | +    aten.atanh_.default,  | 
 | 57 | +    aten.erf.default,  | 
 | 58 | +    aten.erf_.default,  | 
 | 59 | +    aten.erfinv.default,  | 
 | 60 | +    aten.erfinv_.default,  | 
 | 61 | +    aten.relu.default,  | 
 | 62 | +    aten.relu_.default,  | 
 | 63 | +    aten.hardtanh.default,  | 
 | 64 | +    aten.hardtanh_.default,  | 
 | 65 | +    aten.leaky_relu.default,  | 
 | 66 | +    aten.leaky_relu_.default,  | 
 | 67 | +}  | 
 | 68 | + | 
7 | 69 | 
 
  | 
8 | 70 | class DiagonalSparseTensor(torch.Tensor):  | 
9 | 71 | 
 
  | 
@@ -50,10 +112,19 @@ def to_dense(self) -> Tensor:  | 
50 | 112 |         return output  | 
51 | 113 | 
 
  | 
52 | 114 |     @classmethod  | 
53 |  | -    def __torch_dispatch__(  | 
54 |  | -        cls, func: {__name__}, types: Any, args: tuple[()] | Any = (), kwargs: Any = None  | 
55 |  | -    ):  | 
56 |  | -        kwargs = kwargs if kwargs else {}  | 
 | 115 | +    def __torch_dispatch__(cls, func: {__name__}, types: Any, args: tuple = (), kwargs: Any = None):  | 
 | 116 | +        kwargs = {} if kwargs is None else kwargs  | 
 | 117 | + | 
 | 118 | +        # If `func` is a pointwise operator that applies to a single Tensor and such that func(0)=0  | 
 | 119 | +        # Then we can apply the transformation to self._data and wrap the result.  | 
 | 120 | +        if func in _pointwise_functions:  | 
 | 121 | +            assert (  | 
 | 122 | +                isinstance(args, tuple) and len(args) == 1 and func(torch.zeros([])).item() == 0.0  | 
 | 123 | +            )  | 
 | 124 | +            sparse_tensor = args[0]  | 
 | 125 | +            assert isinstance(sparse_tensor, DiagonalSparseTensor)  | 
 | 126 | +            new_data = func(sparse_tensor._data)  | 
 | 127 | +            return DiagonalSparseTensor(new_data, sparse_tensor._v_to_p)  | 
57 | 128 | 
 
  | 
58 | 129 |         # TODO: Handle batched operations (apply to self._data and wrap)  | 
59 | 130 |         # TODO: Handle all operations that can be represented with an einsum by translating them  | 
 | 
0 commit comments