@@ -1,14 +1,18 @@
 import ctypes
+import math
 
 import mlir_finch.execution_engine
 import mlir_finch.passmanager
 from mlir_finch import ir
 from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor
 
+import numpy as np
+
 from ._array import Array
-from ._common import fn_cache
-from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
+from ._common import as_shape, fn_cache
+from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
 from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
+from .levels import StorageFormat, _determine_format
 
 
 @fn_cache
@@ -17,7 +21,6 @@ def get_add_module(
     b_tensor_type: ir.RankedTensorType,
     out_tensor_type: ir.RankedTensorType,
     dtype: DType,
-    rank: int,
 ) -> ir.Module:
     with ir.Location.unknown(ctx):
         module = ir.Module.create()
@@ -31,7 +34,7 @@ def get_add_module(
             raise RuntimeError(f"Can not add {dtype=}.")
 
         dtype = dtype._get_mlir_type()
-        ordering = ir.AffineMap.get_permutation(range(rank))
+        max_rank = out_tensor_type.rank
 
         with ir.InsertionPoint(module.body):
 
@@ -42,8 +45,13 @@ def add(a, b):
                     [out_tensor_type],
                     [a, b],
                     [out],
-                    ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
-                    ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
+                    ir.ArrayAttr.get(
+                        [
+                            ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
+                            for t in (a_tensor_type, b_tensor_type, out_tensor_type)
+                        ]
+                    ),
+                    ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
                 )
                 block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
                 with ir.InsertionPoint(block):
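
A note on the indexing maps above: they are what makes mixed-rank addition line up. Each operand is indexed by a minor identity over the output's rank, so a lower-rank input reads along the trailing output dimensions, mirroring NumPy's trailing-axis broadcasting (size-1 expansion is not covered by these maps). A quick sketch of what they evaluate to, assuming the vendored mlir_finch bindings mirror upstream MLIR's AffineMap API:

    from mlir_finch import ir

    with ir.Context():
        # A rank-1 operand in a rank-3 iteration space reads the trailing dim:
        print(ir.AffineMap.get_minor_identity(3, 1))  # (d0, d1, d2) -> (d2)
        # The output itself gets the full identity:
        print(ir.AffineMap.get_minor_identity(3, 3))  # (d0, d1, d2) -> (d0, d1, d2)
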
@@ -72,7 +80,7 @@ def add(a, b):
         if DEBUG:
             (CWD / "add_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
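
For reference: every engine constructor in this file now reads a shared OPT_LEVEL from ._core instead of hard-coding opt_level=2. Its definition is not part of this diff; one purely hypothetical shape for such a knob:

    import os

    # Hypothetical sketch only: lower the JIT optimization level when debugging.
    DEBUG = bool(int(os.environ.get("DEBUG", "0")))
    OPT_LEVEL = 0 if DEBUG else 2
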
@@ -97,7 +105,7 @@ def reshape(a, shape):
         if DEBUG:
             (CWD / "reshape_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -125,26 +133,94 @@ def broadcast_to(in_tensor):
         if DEBUG:
             (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+@fn_cache
+def get_convert_module(
+    in_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+):
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
 
+            @func.FuncOp.from_py_func(in_tensor_type)
+            def convert(in_tensor):
+                return sparse_tensor.convert(out_tensor_type, in_tensor)
 
-def add(x1: Array, x2: Array) -> Array:
-    ret_storage_format = x1.format
+        convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+        if DEBUG:
+            (CWD / "convert_module.mlir").write_text(str(module))
+        pm.run(module.operation)
+        if DEBUG:
+            (CWD / "convert_module_opt.mlir").write_text(str(module))
+
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+def add(x1: Array, x2: Array, /) -> Array:
+    # TODO: Determine output format via autoscheduler
+    ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
     ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
-    out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
 
-    # TODO: Decide what will be the output tensor_type
     add_module = get_add_module(
         x1._get_mlir_type(),
         x2._get_mlir_type(),
         out_tensor_type=out_tensor_type,
         dtype=x1.dtype,
-        rank=x1.ndim,
     )
     add_module.invoke(
         "add",
         ctypes.pointer(ctypes.pointer(ret_storage)),
         *x1._to_module_arg(),
         *x2._to_module_arg(),
     )
-    return Array(storage=ret_storage, shape=out_tensor_type.shape)
+    return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
+
+
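
A usage sketch of the broadcasting behavior added here (hypothetical session; `_from_numpy` is the private constructor that `reshape` below imports from ._conversions, and the result shape follows from np.broadcast_shapes):

    import numpy as np

    a = _from_numpy(np.ones((3, 4)))
    b = _from_numpy(np.ones(4))  # rank-1, aligned to the trailing axis of `a`
    c = add(a, b)
    assert c.shape == (3, 4)     # np.broadcast_shapes((3, 4), (4,)) == (3, 4)
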
+def asformat(x: Array, /, format: StorageFormat) -> Array:
+    if x.format == format:
+        return x
+
+    out_tensor_type = format._get_mlir_type(shape=x.shape)
+    ret_storage = format._get_ctypes_type(owns_memory=True)()
+
+    convert_module = get_convert_module(
+        x._get_mlir_type(),
+        out_tensor_type,
+    )
+
+    convert_module.invoke(
+        "convert",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=x.shape)
+
+
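
asformat is the eager conversion entry point; @fn_cache on get_convert_module means the JIT-compiled kernel is reused per (input type, output type) pair. A hypothetical round trip, same assumptions as the sketch above:

    import numpy as np

    x = _from_numpy(np.eye(3))
    y = _from_numpy(np.zeros((3, 3)))
    z = asformat(x, format=y.format)  # returns x unchanged if the formats already match
    assert z.shape == x.shape and z.format == y.format
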
+def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
+    from ._conversions import _from_numpy
+
+    shape = as_shape(shape)
+    if math.prod(x.shape) != math.prod(shape):
+        raise ValueError(f"`math.prod(x.shape) != math.prod(shape)`, {x.shape=}, {shape=}")
+
+    ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
+    shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
+    ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
+
+    reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
+
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+        *shape_array._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=shape)
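
A usage sketch for reshape (hypothetical session, same assumptions as above). The element-count check fails fast, before any MLIR module is built or compiled:

    import numpy as np

    x = _from_numpy(np.arange(6.0))
    y = reshape(x, (2, 3))
    assert y.shape == (2, 3)
    # reshape(x, (4, 2)) would raise ValueError: math.prod((4, 2)) == 8 != 6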