@@ -37,25 +37,11 @@ def __new__(cls, data: Tensor, v_to_p: list[int]):
3737 # official API. I am thinking that something that does the
3838 # assert above and this call could be made into a utility function
3939 # that is in the public API
40- return Tensor ._make_wrapper_subclass (
41- cls , [data .shape [i ] for i in v_to_p ], dtype = data .dtype , device = data .device
42- )
43-
44- def __init__ (self , data : Tensor , v_to_p : list [int ]):
45- """
46- Represent a diagonal sparse tensor.
47-
48- :param data: The physical contiguous data.
49- :param v_to_p: Maps virtual dimensions to physical dimensions.
50-
51- An example is `data` of shape `[d_1, d_2, d_3]` and `v_to_p` equal to `[0, 1, 0, 2, 1]`
52- means the virtual shape is `[d_1, d_2, d_1, d_3, d_2]` and the represented Tensor, indexed
53- at `[i, j, k, l, m]` is `0.` unless `i==k` and `j==m`.
54- """
55- # Deliberate omission of `super().__init__()` as we have an unfaithful data.
56- self ._data = data
57- self ._v_to_p = v_to_p
58- self ._v_shape = tuple (data .shape [i ] for i in v_to_p )
40+ shape = [data .shape [i ] for i in v_to_p ]
41+ result = Tensor ._make_wrapper_subclass (cls , shape , dtype = data .dtype , device = data .device )
42+ result ._data = data # type: ignore
43+ result ._v_to_p = v_to_p # type: ignore
44+ result ._v_shape = shape # type: ignore
5945
6046 def to_dense (self ) -> Tensor :
6147 first_indices = dict [int , int ]()
0 commit comments