1
1
import ctypes
2
+ from collections .abc import Iterable
2
3
from typing import Any
3
4
4
5
import mlir .runtime as rt
8
9
import numpy as np
9
10
import scipy .sparse as sps
10
11
11
- from ._common import RefableList , _hold_self_ref_in_ret , _take_owneship , fn_cache
12
+ from ._common import PackedArgumentTuple , _hold_self_ref_in_ret , _take_owneship , fn_cache
12
13
from ._core import ctx , libc
13
14
from ._dtypes import DType , asdtype
14
15
@@ -118,14 +119,16 @@ class Coo(ctypes.Structure):
118
119
_index_dtype = index_dtype
119
120
120
121
@classmethod
121
- def from_sps (cls , arr : sps .coo_array | np .ndarray ) -> "Coo" :
122
+ def from_sps (cls , arr : sps .coo_array | Iterable [ np .ndarray ] ) -> "Coo" :
122
123
if isinstance (arr , sps .coo_array ):
123
- assert arr .has_canonical_format , "COO must have canonical format"
124
+ if not arr .has_canonical_format :
125
+ raise Exception ("COO must have canonical format" )
124
126
np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
125
127
np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
126
128
np_data = arr .data
127
129
else :
128
- assert len (arr ) == 3 , "COO must be comprised of three arrays"
130
+ if len (arr ) != 3 :
131
+ raise Exception ("COO must be comprised of three arrays" )
129
132
np_pos , np_coords , np_data = arr
130
133
131
134
pos = numpy_to_ranked_memref (np_pos )
@@ -142,7 +145,11 @@ def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array | list[np.ndarray]:
142
145
pos = ranked_memref_to_numpy (self .pos )
143
146
coords = ranked_memref_to_numpy (self .coords )[pos [0 ] : pos [1 ]]
144
147
data = ranked_memref_to_numpy (self .data )
145
- return sps .coo_array ((data , coords .T ), shape = shape ) if len (shape ) == 2 else RefableList ([pos , coords , data ])
148
+ return (
149
+ sps .coo_array ((data , coords .T ), shape = shape )
150
+ if len (shape ) == 2
151
+ else PackedArgumentTuple ((pos , coords , data ))
152
+ )
146
153
147
154
def to_module_arg (self ) -> list :
148
155
return [
@@ -201,7 +208,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
201
208
return csf_instance
202
209
203
210
def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
204
- return RefableList ( ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
211
+ return PackedArgumentTuple ( tuple ( ranked_memref_to_numpy (field ) for field in self .get__fields_ () ))
205
212
206
213
def to_module_arg (self ) -> list :
207
214
return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
0 commit comments