@@ -49,6 +49,55 @@ def free_memref(obj: ctypes.Structure) -> None:
49
49
###########
50
50
51
51
52
+ @fn_cache
53
+ def get_sparse_vector_class (
54
+ values_dtype : type [DType ],
55
+ index_dtype : type [DType ],
56
+ ) -> type [ctypes .Structure ]:
57
+ class SparseVector (ctypes .Structure ):
58
+ _fields_ = [
59
+ ("indptr" , get_nd_memref_descr (1 , index_dtype )),
60
+ ("indices" , get_nd_memref_descr (1 , index_dtype )),
61
+ ("data" , get_nd_memref_descr (1 , values_dtype )),
62
+ ]
63
+ dtype = values_dtype
64
+ _index_dtype = index_dtype
65
+
66
+ @classmethod
67
+ def from_sps (cls , arrs : list [np .ndarray ]) -> "SparseVector" :
68
+ sv_instance = cls (* [numpy_to_ranked_memref (arr ) for arr in arrs ])
69
+ for arr in arrs :
70
+ _take_owneship (sv_instance , arr )
71
+ return sv_instance
72
+
73
+ def to_sps (self , shape : tuple [int , ...]) -> int :
74
+ return PackedArgumentTuple (tuple (ranked_memref_to_numpy (field ) for field in self .get__fields_ ()))
75
+
76
+ def to_module_arg (self ) -> list :
77
+ return [
78
+ ctypes .pointer (ctypes .pointer (self .indptr )),
79
+ ctypes .pointer (ctypes .pointer (self .indices )),
80
+ ctypes .pointer (ctypes .pointer (self .data )),
81
+ ]
82
+
83
+ def get__fields_ (self ) -> list :
84
+ return [self .indptr , self .indices , self .data ]
85
+
86
+ @classmethod
87
+ @fn_cache
88
+ def get_tensor_definition (cls , shape : tuple [int , ...]) -> ir .RankedTensorType :
89
+ with ir .Location .unknown (ctx ):
90
+ values_dtype = cls .dtype .get_mlir_type ()
91
+ index_dtype = cls ._index_dtype .get_mlir_type ()
92
+ index_width = getattr (index_dtype , "width" , 0 )
93
+ levels = (sparse_tensor .LevelFormat .compressed ,)
94
+ ordering = ir .AffineMap .get_permutation ([0 ])
95
+ encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
96
+ return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
97
+
98
+ return SparseVector
99
+
100
+
52
101
@fn_cache
53
102
def get_csx_class (
54
103
values_dtype : type [DType ],
@@ -302,6 +351,16 @@ def get_csx_scipy_class(order: str) -> type[sps.sparray]:
302
351
raise Exception (f"Invalid order: { order } " )
303
352
304
353
354
+ _constructor_class_dict = {
355
+ "csr" : get_csx_class ,
356
+ "csc" : get_csx_class ,
357
+ "csf" : get_csf_class ,
358
+ "coo" : get_coo_class ,
359
+ "sparse_vector" : get_sparse_vector_class ,
360
+ "dense" : get_dense_class ,
361
+ }
362
+
363
+
305
364
################
306
365
# Tensor class #
307
366
################
@@ -346,8 +405,8 @@ def __init__(
346
405
self ._obj = obj
347
406
348
407
elif format is not None :
349
- if format in ["csf" , "coo" ]:
350
- fn_format_class = get_csf_class if format == "csf" else get_coo_class
408
+ if format in ["csf" , "coo" , "sparse_vector" ]:
409
+ fn_format_class = _constructor_class_dict [ format ]
351
410
self ._owns_memory = False
352
411
self ._index_dtype = asdtype (np .intp )
353
412
self ._format_class = fn_format_class (self ._values_dtype , self ._index_dtype )
0 commit comments