5
5
import mlir .execution_engine
6
6
import mlir .passmanager
7
7
from mlir import ir
8
+ from mlir import runtime as rt
8
9
from mlir .dialects import arith , bufferization , func , sparse_tensor , tensor
9
10
10
11
import numpy as np
13
14
from ._common import fn_cache
14
15
from ._core import CWD , DEBUG , MLIR_C_RUNNER_UTILS , ctx
15
16
from ._dtypes import DType , Index , asdtype
16
- from ._memref import make_memref_ctype , ranked_memref_from_np
17
17
18
18
19
19
def _hold_self_ref_in_ret (fn ):
@@ -108,7 +108,7 @@ def free_tensor(tensor_shaped):
108
108
@classmethod
109
109
def assemble (cls , module , arr : np .ndarray ) -> ctypes .c_void_p :
110
110
assert arr .ndim == 2
111
- data = ranked_memref_from_np (arr .flatten ())
111
+ data = rt . get_ranked_memref_descriptor (arr .flatten ())
112
112
out = ctypes .c_void_p ()
113
113
module .invoke (
114
114
"assemble" ,
@@ -121,14 +121,14 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
121
121
def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> np .ndarray :
122
122
class Dense (ctypes .Structure ):
123
123
_fields_ = [
124
- ("data" , make_memref_ctype ( dtype , 1 )),
124
+ ("data" , rt . make_nd_memref_descriptor ( 1 , dtype . to_ctype () )),
125
125
("data_len" , np .ctypeslib .c_intp ),
126
126
("shape_x" , np .ctypeslib .c_intp ),
127
127
("shape_y" , np .ctypeslib .c_intp ),
128
128
]
129
129
130
130
def to_np (self ) -> np .ndarray :
131
- data = self .data . to_numpy ( )[: self .data_len ]
131
+ data = rt . ranked_memref_to_numpy ([ self .data ] )[: self .data_len ]
132
132
return data .reshape ((self .shape_x , self .shape_y ))
133
133
134
134
arr = Dense ()
@@ -141,8 +141,107 @@ def to_np(self) -> np.ndarray:
141
141
142
142
143
143
class COOFormat :
144
- # TODO: implement
145
- ...
144
+ @fn_cache
145
+ def get_module (shape : tuple [int ], values_dtype : type [DType ], index_dtype : type [DType ]):
146
+ with ir .Location .unknown (ctx ):
147
+ module = ir .Module .create ()
148
+ values_dtype = values_dtype .get_mlir_type ()
149
+ index_dtype = index_dtype .get_mlir_type ()
150
+ index_width = getattr (index_dtype , "width" , 0 )
151
+ compressed_lvl = sparse_tensor .EncodingAttr .build_level_type (
152
+ sparse_tensor .LevelFormat .compressed , [sparse_tensor .LevelProperty .non_unique ]
153
+ )
154
+ levels = (compressed_lvl , sparse_tensor .LevelFormat .singleton )
155
+ ordering = ir .AffineMap .get_permutation ([0 , 1 ])
156
+ encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
157
+ coo_shaped = ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
158
+
159
+ tensor_1d_index = tensor .RankedTensorType .get ([ir .ShapedType .get_dynamic_size ()], index_dtype )
160
+ tensor_2d_index = tensor .RankedTensorType .get ([ir .ShapedType .get_dynamic_size (), len (shape )], index_dtype )
161
+ tensor_1d_values = tensor .RankedTensorType .get ([ir .ShapedType .get_dynamic_size ()], values_dtype )
162
+
163
+ with ir .InsertionPoint (module .body ):
164
+
165
+ @func .FuncOp .from_py_func (tensor_1d_index , tensor_2d_index , tensor_1d_values )
166
+ def assemble (pos , index , values ):
167
+ return sparse_tensor .assemble (coo_shaped , (pos , index ), values )
168
+
169
+ @func .FuncOp .from_py_func (coo_shaped )
170
+ def disassemble (tensor_shaped ):
171
+ nse = sparse_tensor .number_of_entries (tensor_shaped )
172
+ pos = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 2 )], index_dtype )
173
+ index = tensor .EmptyOp ([nse , 2 ], index_dtype )
174
+ values = tensor .EmptyOp ([nse ], values_dtype )
175
+ pos , index , values , pos_len , index_len , values_len = sparse_tensor .disassemble (
176
+ (tensor_1d_index , tensor_2d_index ),
177
+ tensor_1d_values ,
178
+ (index_dtype , index_dtype ),
179
+ index_dtype ,
180
+ tensor_shaped ,
181
+ (pos , index ),
182
+ values ,
183
+ )
184
+ shape_consts = [arith .constant (index_dtype , s ) for s in shape ]
185
+ return pos , index , values , pos_len , index_len , values_len , * shape_consts
186
+
187
+ @func .FuncOp .from_py_func (coo_shaped )
188
+ def free_tensor (tensor_shaped ):
189
+ bufferization .dealloc_tensor (tensor_shaped )
190
+
191
+ assemble .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
192
+ disassemble .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
193
+ free_tensor .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
194
+ if DEBUG :
195
+ (CWD / "coo_module.mlir" ).write_text (str (module ))
196
+ pm = mlir .passmanager .PassManager .parse ("builtin.module(sparsifier{create-sparse-deallocs=1})" )
197
+ pm .run (module .operation )
198
+ if DEBUG :
199
+ (CWD / "coo_module_opt.mlir" ).write_text (str (module ))
200
+
201
+ module = mlir .execution_engine .ExecutionEngine (module , opt_level = 2 , shared_libs = [MLIR_C_RUNNER_UTILS ])
202
+ return (module , coo_shaped )
203
+
204
+ @classmethod
205
+ def assemble (cls , module : ir .Module , arr : sps .coo_array ) -> ctypes .c_void_p :
206
+ out = ctypes .c_void_p ()
207
+ module .invoke (
208
+ "assemble" ,
209
+ ctypes .pointer (
210
+ ctypes .pointer (rt .get_ranked_memref_descriptor (np .array ([0 , arr .size ], dtype = arr .coords [0 ].dtype )))
211
+ ),
212
+ ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (np .stack (arr .coords , axis = 1 )))),
213
+ ctypes .pointer (ctypes .pointer (rt .get_ranked_memref_descriptor (arr .data ))),
214
+ ctypes .pointer (out ),
215
+ )
216
+ return out
217
+
218
+ @classmethod
219
+ def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> sps .coo_array :
220
+ class Coo (ctypes .Structure ):
221
+ _fields_ = [
222
+ ("pos" , rt .make_nd_memref_descriptor (1 , Index .to_ctype ())),
223
+ ("index" , rt .make_nd_memref_descriptor (2 , Index .to_ctype ())),
224
+ ("values" , rt .make_nd_memref_descriptor (1 , dtype .to_ctype ())),
225
+ ("pos_len" , np .ctypeslib .c_intp ),
226
+ ("index_len" , np .ctypeslib .c_intp ),
227
+ ("values_len" , np .ctypeslib .c_intp ),
228
+ ("shape_x" , np .ctypeslib .c_intp ),
229
+ ("shape_y" , np .ctypeslib .c_intp ),
230
+ ]
231
+
232
+ def to_sps (self ) -> sps .coo_array :
233
+ pos = rt .ranked_memref_to_numpy ([self .pos ])[: self .pos_len ]
234
+ index = rt .ranked_memref_to_numpy ([self .index ])[pos [0 ] : pos [1 ]]
235
+ values = rt .ranked_memref_to_numpy ([self .values ])[: self .values_len ]
236
+ return sps .coo_array ((values , index .T ), shape = (self .shape_x , self .shape_y ))
237
+
238
+ arr = Coo ()
239
+ module .invoke (
240
+ "disassemble" ,
241
+ ctypes .pointer (ctypes .pointer (arr )),
242
+ ctypes .pointer (ptr ),
243
+ )
244
+ return arr .to_sps ()
146
245
147
246
148
247
class CSRFormat :
@@ -207,9 +306,9 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
207
306
out = ctypes .c_void_p ()
208
307
module .invoke (
209
308
"assemble" ,
210
- ctypes .pointer (ctypes .pointer (ranked_memref_from_np (arr .indptr ))),
211
- ctypes .pointer (ctypes .pointer (ranked_memref_from_np (arr .indices ))),
212
- ctypes .pointer (ctypes .pointer (ranked_memref_from_np (arr .data ))),
309
+ ctypes .pointer (ctypes .pointer (rt . get_ranked_memref_descriptor (arr .indptr ))),
310
+ ctypes .pointer (ctypes .pointer (rt . get_ranked_memref_descriptor (arr .indices ))),
311
+ ctypes .pointer (ctypes .pointer (rt . get_ranked_memref_descriptor (arr .data ))),
213
312
ctypes .pointer (out ),
214
313
)
215
314
return out
@@ -218,9 +317,9 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
218
317
def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> sps .csr_array :
219
318
class Csr (ctypes .Structure ):
220
319
_fields_ = [
221
- ("pos" , make_memref_ctype ( Index , 1 )),
222
- ("crd" , make_memref_ctype ( Index , 1 )),
223
- ("data" , make_memref_ctype ( dtype , 1 )),
320
+ ("pos" , rt . make_nd_memref_descriptor ( 1 , Index . to_ctype () )),
321
+ ("crd" , rt . make_nd_memref_descriptor ( 1 , Index . to_ctype () )),
322
+ ("data" , rt . make_nd_memref_descriptor ( 1 , dtype . to_ctype () )),
224
323
("pos_len" , np .ctypeslib .c_intp ),
225
324
("crd_len" , np .ctypeslib .c_intp ),
226
325
("data_len" , np .ctypeslib .c_intp ),
@@ -229,9 +328,9 @@ class Csr(ctypes.Structure):
229
328
]
230
329
231
330
def to_sps (self ) -> sps .csr_array :
232
- pos = self .pos . to_numpy ( )[: self .pos_len ]
233
- crd = self .crd . to_numpy ( )[: self .crd_len ]
234
- data = self .data . to_numpy ( )[: self .data_len ]
331
+ pos = rt . ranked_memref_to_numpy ([ self .pos ] )[: self .pos_len ]
332
+ crd = rt . ranked_memref_to_numpy ([ self .crd ] )[: self .crd_len ]
333
+ data = rt . ranked_memref_to_numpy ([ self .data ] )[: self .data_len ]
235
334
return sps .csr_array ((data , crd , pos ), shape = (self .shape_x , self .shape_y ))
236
335
237
336
arr = Csr ()
@@ -257,9 +356,16 @@ def asarray(obj) -> Tensor:
257
356
258
357
# TODO: support other scipy formats
259
358
if _is_scipy_sparse_obj (obj ):
260
- format_class = CSRFormat
261
- # This can be int32 or int64
262
- index_dtype = asdtype (obj .indptr .dtype )
359
+ if obj .format == "csr" :
360
+ format_class = CSRFormat
361
+ # This can be int32 or int64
362
+ index_dtype = asdtype (obj .indptr .dtype )
363
+ elif obj .format == "coo" :
364
+ format_class = COOFormat
365
+ # This can be int32 or int64
366
+ index_dtype = asdtype (obj .coords [0 ].dtype )
367
+ else :
368
+ raise Exception (f"{ obj .format } SciPy format not supported." )
263
369
elif _is_numpy_obj (obj ):
264
370
format_class = DenseFormat
265
371
index_dtype = Index
0 commit comments