1
1
import ctypes
2
+ from collections .abc import Iterable
2
3
from typing import Any
3
4
4
5
import mlir .runtime as rt
8
9
import numpy as np
9
10
import scipy .sparse as sps
10
11
11
- from ._common import _hold_self_ref_in_ret , _take_owneship , fn_cache
12
+ from ._common import PackedArgumentTuple , _hold_self_ref_in_ret , _take_owneship , fn_cache
12
13
from ._core import ctx , libc
13
14
from ._dtypes import DType , asdtype
14
15
@@ -118,26 +119,37 @@ class Coo(ctypes.Structure):
118
119
_index_dtype = index_dtype
119
120
120
121
@classmethod
121
- def from_sps (cls , arr : sps .coo_array ) -> "Coo" :
122
- assert arr .has_canonical_format , "COO must have canonical format"
123
- np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
124
- np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
122
+ def from_sps (cls , arr : sps .coo_array | Iterable [np .ndarray ]) -> "Coo" :
123
+ if isinstance (arr , sps .coo_array ):
124
+ if not arr .has_canonical_format :
125
+ raise Exception ("COO must have canonical format" )
126
+ np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
127
+ np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
128
+ np_data = arr .data
129
+ else :
130
+ if len (arr ) != 3 :
131
+ raise Exception ("COO must be comprised of three arrays" )
132
+ np_pos , np_coords , np_data = arr
133
+
125
134
pos = numpy_to_ranked_memref (np_pos )
126
135
coords = numpy_to_ranked_memref (np_coords )
127
- data = numpy_to_ranked_memref (arr .data )
128
-
136
+ data = numpy_to_ranked_memref (np_data )
129
137
coo_instance = cls (pos = pos , coords = coords , data = data )
130
138
_take_owneship (coo_instance , np_pos )
131
139
_take_owneship (coo_instance , np_coords )
132
- _take_owneship (coo_instance , arr )
140
+ _take_owneship (coo_instance , np_data )
133
141
134
142
return coo_instance
135
143
136
- def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array :
144
+ def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array | list [ np . ndarray ] :
137
145
pos = ranked_memref_to_numpy (self .pos )
138
146
coords = ranked_memref_to_numpy (self .coords )[pos [0 ] : pos [1 ]]
139
147
data = ranked_memref_to_numpy (self .data )
140
- return sps .coo_array ((data , coords .T ), shape = shape )
148
+ return (
149
+ sps .coo_array ((data , coords .T ), shape = shape )
150
+ if len (shape ) == 2
151
+ else PackedArgumentTuple ((pos , coords , data ))
152
+ )
141
153
142
154
def to_module_arg (self ) -> list :
143
155
return [
@@ -159,8 +171,13 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
159
171
compressed_lvl = sparse_tensor .EncodingAttr .build_level_type (
160
172
sparse_tensor .LevelFormat .compressed , [sparse_tensor .LevelProperty .non_unique ]
161
173
)
162
- levels = (compressed_lvl , sparse_tensor .LevelFormat .singleton )
163
- ordering = ir .AffineMap .get_permutation ([0 , 1 ])
174
+ mid_singleton_lvls = [
175
+ sparse_tensor .EncodingAttr .build_level_type (
176
+ sparse_tensor .LevelFormat .singleton , [sparse_tensor .LevelProperty .non_unique ]
177
+ )
178
+ ] * (len (shape ) - 2 )
179
+ levels = (compressed_lvl , * mid_singleton_lvls , sparse_tensor .LevelFormat .singleton )
180
+ ordering = ir .AffineMap .get_permutation ([* range (len (shape ))])
164
181
encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
165
182
return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
166
183
@@ -191,10 +208,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
191
208
return csf_instance
192
209
193
210
def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
194
- class List (list ):
195
- pass
196
-
197
- return List (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
211
+ return PackedArgumentTuple (tuple (ranked_memref_to_numpy (field ) for field in self .get__fields_ ()))
198
212
199
213
def to_module_arg (self ) -> list :
200
214
return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
@@ -310,20 +324,20 @@ def __init__(
310
324
311
325
if obj .format in ("csr" , "csc" ):
312
326
order = "r" if obj .format == "csr" else "c"
313
- index_dtype = asdtype (obj .indptr .dtype )
314
- self ._format_class = get_csx_class (self ._values_dtype , index_dtype , order )
327
+ self . _index_dtype = asdtype (obj .indptr .dtype )
328
+ self ._format_class = get_csx_class (self ._values_dtype , self . _index_dtype , order )
315
329
self ._obj = self ._format_class .from_sps (obj )
316
330
elif obj .format == "coo" :
317
- index_dtype = asdtype (obj .coords [0 ].dtype )
318
- self ._format_class = get_coo_class (self ._values_dtype , index_dtype )
331
+ self . _index_dtype = asdtype (obj .coords [0 ].dtype )
332
+ self ._format_class = get_coo_class (self ._values_dtype , self . _index_dtype )
319
333
self ._obj = self ._format_class .from_sps (obj )
320
334
else :
321
335
raise Exception (f"{ obj .format } SciPy format not supported." )
322
336
323
337
elif _is_numpy_obj (obj ):
324
338
self ._owns_memory = False
325
- index_dtype = asdtype (np .intp )
326
- self ._format_class = get_dense_class (self ._values_dtype , index_dtype )
339
+ self . _index_dtype = asdtype (np .intp )
340
+ self ._format_class = get_dense_class (self ._values_dtype , self . _index_dtype )
327
341
self ._obj = self ._format_class .from_sps (obj )
328
342
329
343
elif _is_mlir_obj (obj ):
@@ -332,11 +346,13 @@ def __init__(
332
346
self ._obj = obj
333
347
334
348
elif format is not None :
335
- if format == "csf" :
349
+ if format in ["csf" , "coo" ]:
350
+ fn_format_class = get_csf_class if format == "csf" else get_coo_class
336
351
self ._owns_memory = False
337
- index_dtype = asdtype (np .intp )
338
- self ._format_class = get_csf_class (self ._values_dtype , index_dtype )
352
+ self . _index_dtype = asdtype (np .intp )
353
+ self ._format_class = fn_format_class (self ._values_dtype , self . _index_dtype )
339
354
self ._obj = self ._format_class .from_sps (obj )
355
+
340
356
else :
341
357
raise Exception (f"Format { format } not supported." )
342
358
0 commit comments