1
1
import ctypes
2
- import ctypes .util
3
2
import functools
4
3
import weakref
5
4
@@ -61,7 +60,7 @@ def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
61
60
values_dtype = values_dtype .get_mlir_type ()
62
61
index_dtype = index_dtype .get_mlir_type ()
63
62
index_width = getattr (index_dtype , "width" , 0 )
64
- levels = (sparse_tensor .LevelType .dense , sparse_tensor .LevelType .dense )
63
+ levels = (sparse_tensor .LevelFormat .dense , sparse_tensor .LevelFormat .dense )
65
64
ordering = ir .AffineMap .get_permutation ([0 , 1 ])
66
65
encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
67
66
dense_shaped = ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
@@ -71,19 +70,19 @@ def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
71
70
72
71
@func .FuncOp .from_py_func (tensor_1d )
73
72
def assemble (data ):
74
- return sparse_tensor .assemble (dense_shaped , data , [] )
73
+ return sparse_tensor .assemble (dense_shaped , [], data )
75
74
76
75
@func .FuncOp .from_py_func (dense_shaped )
77
76
def disassemble (tensor_shaped ):
78
77
data = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], values_dtype )
79
78
data , data_len = sparse_tensor .disassemble (
79
+ [],
80
80
tensor_1d ,
81
81
[],
82
82
index_dtype ,
83
- [],
84
83
tensor_shaped ,
85
- data ,
86
84
[],
85
+ data ,
87
86
)
88
87
shape_x = arith .constant (index_dtype , shape [0 ])
89
88
shape_y = arith .constant (index_dtype , shape [1 ])
@@ -154,7 +153,7 @@ def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[D
154
153
values_dtype = values_dtype .get_mlir_type ()
155
154
index_dtype = index_dtype .get_mlir_type ()
156
155
index_width = getattr (index_dtype , "width" , 0 )
157
- levels = (sparse_tensor .LevelType .dense , sparse_tensor .LevelType .compressed )
156
+ levels = (sparse_tensor .LevelFormat .dense , sparse_tensor .LevelFormat .compressed )
158
157
ordering = ir .AffineMap .get_permutation ([0 , 1 ])
159
158
encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
160
159
csr_shaped = ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
@@ -166,25 +165,25 @@ def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[D
166
165
167
166
@func .FuncOp .from_py_func (tensor_1d_index , tensor_1d_index , tensor_1d_values )
168
167
def assemble (pos , crd , data ):
169
- return sparse_tensor .assemble (csr_shaped , data , (pos , crd ))
168
+ return sparse_tensor .assemble (csr_shaped , (pos , crd ), data )
170
169
171
170
@func .FuncOp .from_py_func (csr_shaped )
172
171
def disassemble (tensor_shaped ):
173
172
pos = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], index_dtype )
174
173
crd = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], index_dtype )
175
174
data = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], values_dtype )
176
- data , pos , crd , data_len , pos_len , crd_len = sparse_tensor .disassemble (
177
- tensor_1d_values ,
175
+ pos , crd , data , pos_len , crd_len , data_len = sparse_tensor .disassemble (
178
176
(tensor_1d_index , tensor_1d_index ),
179
- index_dtype ,
177
+ tensor_1d_values ,
180
178
(index_dtype , index_dtype ),
179
+ index_dtype ,
181
180
tensor_shaped ,
182
- data ,
183
181
(pos , crd ),
182
+ data ,
184
183
)
185
184
shape_x = arith .constant (index_dtype , shape [0 ])
186
185
shape_y = arith .constant (index_dtype , shape [1 ])
187
- return data , pos , crd , data_len , pos_len , crd_len , shape_x , shape_y
186
+ return pos , crd , data , pos_len , crd_len , data_len , shape_x , shape_y
188
187
189
188
@func .FuncOp .from_py_func (csr_shaped )
190
189
def free_tensor (tensor_shaped ):
@@ -219,12 +218,12 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
219
218
def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> sps .csr_array :
220
219
class Csr (ctypes .Structure ):
221
220
_fields_ = [
222
- ("data" , make_memref_ctype (dtype , 1 )),
223
221
("pos" , make_memref_ctype (Index , 1 )),
224
222
("crd" , make_memref_ctype (Index , 1 )),
225
- ("data_len " , np . ctypeslib . c_intp ),
223
+ ("data " , make_memref_ctype ( dtype , 1 ) ),
226
224
("pos_len" , np .ctypeslib .c_intp ),
227
225
("crd_len" , np .ctypeslib .c_intp ),
226
+ ("data_len" , np .ctypeslib .c_intp ),
228
227
("shape_x" , np .ctypeslib .c_intp ),
229
228
("shape_y" , np .ctypeslib .c_intp ),
230
229
]
0 commit comments