@@ -108,11 +108,11 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
108108
109109
110110@fn_cache
111- def get_coo_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type [ctypes .Structure ]:
111+ def get_coo_class (values_dtype : type [DType ], index_dtype : type [DType ], * , rank : int = 2 ) -> type [ctypes .Structure ]:
112112 class Coo (ctypes .Structure ):
113113 _fields_ = [
114114 ("pos" , get_nd_memref_descr (1 , index_dtype )),
115- ( "coords " , get_nd_memref_descr (2 , index_dtype )),
115+ * [( f"coords_ { i } " , get_nd_memref_descr (1 , index_dtype )) for i in range ( rank )] ,
116116 ("data" , get_nd_memref_descr (1 , values_dtype )),
117117 ]
118118 dtype = values_dtype
@@ -124,42 +124,46 @@ def from_sps(cls, arr: sps.coo_array | Iterable[np.ndarray]) -> "Coo":
124124 if not arr .has_canonical_format :
125125 raise Exception ("COO must have canonical format" )
126126 np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
127- np_coords = np .stack ( arr . coords , axis = 1 , dtype = index_dtype .np_dtype )
127+ np_coords = [ np .array ( coord , dtype = index_dtype .np_dtype ) for coord in arr . coords ]
128128 np_data = arr .data
129129 else :
130130 if len (arr ) != 3 :
131131 raise Exception ("COO must be comprised of three arrays" )
132132 np_pos , np_coords , np_data = arr
133133
134134 pos = numpy_to_ranked_memref (np_pos )
135- coords = numpy_to_ranked_memref (np_coords )
135+ coords = [ numpy_to_ranked_memref (coord ) for coord in np_coords ]
136136 data = numpy_to_ranked_memref (np_data )
137- coo_instance = cls (pos = pos , coords = coords , data = data )
137+ coo_instance = cls (pos , * ( coords + [ data ]) )
138138 _take_owneship (coo_instance , np_pos )
139- _take_owneship (coo_instance , np_coords )
139+ for coord in np_coords :
140+ _take_owneship (coo_instance , coord )
140141 _take_owneship (coo_instance , np_data )
141142
142143 return coo_instance
143144
144145 def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array | list [np .ndarray ]:
145146 pos = ranked_memref_to_numpy (self .pos )
146- coords = ranked_memref_to_numpy (self . coords )[ pos [ 0 ] : pos [ 1 ] ]
147+ coords = [ ranked_memref_to_numpy (coord ) for coord in self . get_coord_list () ]
147148 data = ranked_memref_to_numpy (self .data )
148149 return (
149- sps .coo_array ((data , coords . T ), shape = shape )
150+ sps .coo_array ((data , np . stack ( coords , axis = 0 , dtype = index_dtype . np_dtype ) ), shape = shape )
150151 if len (shape ) == 2
151152 else PackedArgumentTuple ((pos , coords , data ))
152153 )
153154
154155 def to_module_arg (self ) -> list :
155156 return [
156157 ctypes .pointer (ctypes .pointer (self .pos )),
157- ctypes .pointer (ctypes .pointer (self .coords )) ,
158+ * [ ctypes .pointer (ctypes .pointer (coord )) for coord in self .get_coord_list ()] ,
158159 ctypes .pointer (ctypes .pointer (self .data )),
159160 ]
160161
161162 def get__fields_ (self ) -> list :
162- return [self .pos , self .coords , self .data ]
163+ return [self .pos , * self .get_coord_list (), self .data ]
164+
165+ def get_coord_list (self ) -> list :
166+ return [getattr (self , f"coords_{ i } " ) for i in range (rank )]
163167
164168 @classmethod
165169 @fn_cache
@@ -173,10 +177,14 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
173177 )
174178 mid_singleton_lvls = [
175179 sparse_tensor .EncodingAttr .build_level_type (
176- sparse_tensor .LevelFormat .singleton , [sparse_tensor .LevelProperty .non_unique ]
180+ sparse_tensor .LevelFormat .singleton ,
181+ [sparse_tensor .LevelProperty .non_unique , sparse_tensor .LevelProperty .soa ],
177182 )
178183 ] * (len (shape ) - 2 )
179- levels = (compressed_lvl , * mid_singleton_lvls , sparse_tensor .LevelFormat .singleton )
184+ last_singleton_lvl = sparse_tensor .EncodingAttr .build_level_type (
185+ sparse_tensor .LevelFormat .singleton , [sparse_tensor .LevelProperty .soa ]
186+ )
187+ levels = (compressed_lvl , * mid_singleton_lvls , last_singleton_lvl )
180188 ordering = ir .AffineMap .get_permutation ([* range (len (shape ))])
181189 encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
182190 return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
@@ -320,6 +328,7 @@ def __init__(
320328 self ._values_dtype = dtype if dtype is not None else asdtype (obj .dtype )
321329
322330 if _is_scipy_sparse_obj (obj ):
331+ self .format = obj .format
323332 self ._owns_memory = False
324333
325334 if obj .format in ("csr" , "csc" ):
@@ -335,22 +344,26 @@ def __init__(
335344 raise Exception (f"{ obj .format } SciPy format not supported." )
336345
337346 elif _is_numpy_obj (obj ):
347+ self .format = "dense"
338348 self ._owns_memory = False
339349 self ._index_dtype = asdtype (np .intp )
340350 self ._format_class = get_dense_class (self ._values_dtype , self ._index_dtype )
341351 self ._obj = self ._format_class .from_sps (obj )
342352
343353 elif _is_mlir_obj (obj ):
354+ self .format = "custom"
344355 self ._owns_memory = True
345356 self ._format_class = type (obj )
346357 self ._obj = obj
347358
348359 elif format is not None :
360+ self .format = format
349361 if format in ["csf" , "coo" ]:
350362 fn_format_class = get_csf_class if format == "csf" else get_coo_class
363+ kwargs = {} if format == "csf" else {"rank" : len (self .shape )}
351364 self ._owns_memory = False
352365 self ._index_dtype = asdtype (np .intp )
353- self ._format_class = fn_format_class (self ._values_dtype , self ._index_dtype )
366+ self ._format_class = fn_format_class (self ._values_dtype , self ._index_dtype , ** kwargs )
354367 self ._obj = self ._format_class .from_sps (obj )
355368
356369 else :
0 commit comments