8
8
import numpy as np
9
9
import scipy .sparse as sps
10
10
11
- from ._common import _hold_self_ref_in_ret , _take_owneship , fn_cache
11
+ from ._common import RefableList , _hold_self_ref_in_ret , _take_owneship , fn_cache
12
12
from ._core import ctx , libc
13
13
from ._dtypes import DType , asdtype
14
14
@@ -118,26 +118,31 @@ class Coo(ctypes.Structure):
118
118
_index_dtype = index_dtype
119
119
120
120
@classmethod
121
- def from_sps (cls , arr : sps .coo_array ) -> "Coo" :
122
- assert arr .has_canonical_format , "COO must have canonical format"
123
- np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
124
- np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
121
+ def from_sps (cls , arr : sps .coo_array | np .ndarray ) -> "Coo" :
122
+ if isinstance (arr , sps .coo_array ):
123
+ assert arr .has_canonical_format , "COO must have canonical format"
124
+ np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
125
+ np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
126
+ np_data = arr .data
127
+ else :
128
+ assert len (arr ) == 3 , "COO must be comprised of three arrays"
129
+ np_pos , np_coords , np_data = arr
130
+
125
131
pos = numpy_to_ranked_memref (np_pos )
126
132
coords = numpy_to_ranked_memref (np_coords )
127
- data = numpy_to_ranked_memref (arr .data )
128
-
133
+ data = numpy_to_ranked_memref (np_data )
129
134
coo_instance = cls (pos = pos , coords = coords , data = data )
130
135
_take_owneship (coo_instance , np_pos )
131
136
_take_owneship (coo_instance , np_coords )
132
- _take_owneship (coo_instance , arr )
137
+ _take_owneship (coo_instance , np_data )
133
138
134
139
return coo_instance
135
140
136
- def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array :
141
+ def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array | list [ np . ndarray ] :
137
142
pos = ranked_memref_to_numpy (self .pos )
138
143
coords = ranked_memref_to_numpy (self .coords )[pos [0 ] : pos [1 ]]
139
144
data = ranked_memref_to_numpy (self .data )
140
- return sps .coo_array ((data , coords .T ), shape = shape )
145
+ return sps .coo_array ((data , coords .T ), shape = shape ) if len ( shape ) == 2 else RefableList ([ pos , coords , data ])
141
146
142
147
def to_module_arg (self ) -> list :
143
148
return [
@@ -159,8 +164,13 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
159
164
compressed_lvl = sparse_tensor .EncodingAttr .build_level_type (
160
165
sparse_tensor .LevelFormat .compressed , [sparse_tensor .LevelProperty .non_unique ]
161
166
)
162
- levels = (compressed_lvl , sparse_tensor .LevelFormat .singleton )
163
- ordering = ir .AffineMap .get_permutation ([0 , 1 ])
167
+ mid_singleton_lvls = [
168
+ sparse_tensor .EncodingAttr .build_level_type (
169
+ sparse_tensor .LevelFormat .singleton , [sparse_tensor .LevelProperty .non_unique ]
170
+ )
171
+ ] * (len (shape ) - 2 )
172
+ levels = (compressed_lvl , * mid_singleton_lvls , sparse_tensor .LevelFormat .singleton )
173
+ ordering = ir .AffineMap .get_permutation ([* range (len (shape ))])
164
174
encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
165
175
return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
166
176
@@ -191,10 +201,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
191
201
return csf_instance
192
202
193
203
def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
194
- class List (list ):
195
- pass
196
-
197
- return List (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
204
+ return RefableList (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
198
205
199
206
def to_module_arg (self ) -> list :
200
207
return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
@@ -310,20 +317,20 @@ def __init__(
310
317
311
318
if obj .format in ("csr" , "csc" ):
312
319
order = "r" if obj .format == "csr" else "c"
313
- index_dtype = asdtype (obj .indptr .dtype )
314
- self ._format_class = get_csx_class (self ._values_dtype , index_dtype , order )
320
+ self . _index_dtype = asdtype (obj .indptr .dtype )
321
+ self ._format_class = get_csx_class (self ._values_dtype , self . _index_dtype , order )
315
322
self ._obj = self ._format_class .from_sps (obj )
316
323
elif obj .format == "coo" :
317
- index_dtype = asdtype (obj .coords [0 ].dtype )
318
- self ._format_class = get_coo_class (self ._values_dtype , index_dtype )
324
+ self . _index_dtype = asdtype (obj .coords [0 ].dtype )
325
+ self ._format_class = get_coo_class (self ._values_dtype , self . _index_dtype )
319
326
self ._obj = self ._format_class .from_sps (obj )
320
327
else :
321
328
raise Exception (f"{ obj .format } SciPy format not supported." )
322
329
323
330
elif _is_numpy_obj (obj ):
324
331
self ._owns_memory = False
325
- index_dtype = asdtype (np .intp )
326
- self ._format_class = get_dense_class (self ._values_dtype , index_dtype )
332
+ self . _index_dtype = asdtype (np .intp )
333
+ self ._format_class = get_dense_class (self ._values_dtype , self . _index_dtype )
327
334
self ._obj = self ._format_class .from_sps (obj )
328
335
329
336
elif _is_mlir_obj (obj ):
@@ -332,11 +339,13 @@ def __init__(
332
339
self ._obj = obj
333
340
334
341
elif format is not None :
335
- if format == "csf" :
342
+ if format in ["csf" , "coo" ]:
343
+ fn_format_class = get_csf_class if format == "csf" else get_coo_class
336
344
self ._owns_memory = False
337
- index_dtype = asdtype (np .intp )
338
- self ._format_class = get_csf_class (self ._values_dtype , index_dtype )
345
+ self . _index_dtype = asdtype (np .intp )
346
+ self ._format_class = fn_format_class (self ._values_dtype , self . _index_dtype )
339
347
self ._obj = self ._format_class .from_sps (obj )
348
+
340
349
else :
341
350
raise Exception (f"Format { format } not supported." )
342
351
0 commit comments