@@ -1,14 +1,18 @@
 import ctypes
+import math
 
 import mlir_finch.execution_engine
 import mlir_finch.passmanager
 from mlir_finch import ir
 from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor
 
+import numpy as np
+
 from ._array import Array
-from ._common import fn_cache
-from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
+from ._common import as_shape, fn_cache
+from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
 from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
+from .levels import StorageFormat, _determine_format
 
 
 @fn_cache
@@ -17,7 +21,6 @@ def get_add_module(
     b_tensor_type: ir.RankedTensorType,
     out_tensor_type: ir.RankedTensorType,
     dtype: DType,
-    rank: int,
 ) -> ir.Module:
     with ir.Location.unknown(ctx):
         module = ir.Module.create()
@@ -31,7 +34,7 @@ def get_add_module(
             raise RuntimeError(f"Can not add {dtype=}.")
 
         dtype = dtype._get_mlir_type()
-        ordering = ir.AffineMap.get_permutation(range(rank))
+        max_rank = out_tensor_type.rank
 
         with ir.InsertionPoint(module.body):
 
@@ -42,8 +45,13 @@ def add(a, b):
                 [out_tensor_type],
                 [a, b],
                 [out],
-                ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
-                ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
+                ir.ArrayAttr.get(
+                    [
+                        ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
+                        for t in (a_tensor_type, b_tensor_type, out_tensor_type)
+                    ]
+                ),
+                ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
             )
             block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
             with ir.InsertionPoint(block):
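# --- Editor's note (not part of the diff): a minimal sketch of the indexing maps
# built above, assuming mlir_finch mirrors upstream MLIR's Python bindings, where
# AffineMap.get_minor_identity(dims, results) keeps only the trailing `results`
# dimensions. Lower-rank operands therefore index only the trailing axes of the
# output, which gives NumPy-style broadcasting over the leading axes.
from mlir_finch import ir

with ir.Context():
    for operand_rank in (3, 2, 1):
        print(ir.AffineMap.get_minor_identity(3, operand_rank))
    # (d0, d1, d2) -> (d0, d1, d2)
    # (d0, d1, d2) -> (d1, d2)
    # (d0, d1, d2) -> (d2)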
@@ -72,7 +80,7 @@ def add(a, b):
         if DEBUG:
             (CWD / "add_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -97,7 +105,7 @@ def reshape(a, shape):
         if DEBUG:
             (CWD / "reshape_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
 
 
 @fn_cache
@@ -125,26 +133,94 @@ def broadcast_to(in_tensor):
         if DEBUG:
             (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))
 
-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+@fn_cache
+def get_convert_module(
+    in_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+):
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
 
+            @func.FuncOp.from_py_func(in_tensor_type)
+            def convert(in_tensor):
+                return sparse_tensor.convert(out_tensor_type, in_tensor)
 
-def add(x1: Array, x2: Array) -> Array:
-    ret_storage_format = x1.format
+        convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+        if DEBUG:
+            (CWD / "convert_module.mlir").write_text(str(module))
+        pm.run(module.operation)
+        if DEBUG:
+            (CWD / "convert_module_opt.mlir").write_text(str(module))
+
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
+def add(x1: Array, x2: Array, /) -> Array:
+    # TODO: Determine output format via autoscheduler
+    ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
     ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
-    out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))
 
-    # TODO: Decide what will be the output tensor_type
     add_module = get_add_module(
         x1._get_mlir_type(),
         x2._get_mlir_type(),
         out_tensor_type=out_tensor_type,
         dtype=x1.dtype,
-        rank=x1.ndim,
     )
     add_module.invoke(
         "add",
         ctypes.pointer(ctypes.pointer(ret_storage)),
         *x1._to_module_arg(),
         *x2._to_module_arg(),
     )
-    return Array(storage=ret_storage, shape=out_tensor_type.shape)
+    return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))
+
+
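# --- Editor's note (not part of the diff): a hedged usage sketch for add() with
# operands of different rank. The conversion helper _from_numpy is the same
# private entry point reshape() uses further down; whether it accepts arbitrary
# dense arrays like this is an assumption.
import numpy as np
from ._conversions import _from_numpy

x1 = _from_numpy(np.arange(6.0).reshape(2, 3))
x2 = _from_numpy(np.ones(3))
out = add(x1, x2)
assert out.shape == (2, 3)  # output shape from np.broadcast_shapes(x1.shape, x2.shape)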
+def asformat(x: Array, /, format: StorageFormat) -> Array:
+    if x.format == format:
+        return x
+
+    out_tensor_type = format._get_mlir_type(shape=x.shape)
+    ret_storage = format._get_ctypes_type(owns_memory=True)()
+
+    convert_module = get_convert_module(
+        x._get_mlir_type(),
+        out_tensor_type,
+    )
+
+    convert_module.invoke(
+        "convert",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=x.shape)
+
+
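# --- Editor's note (not part of the diff): a hedged sketch of asformat(). The
# names `x` (an existing Array) and `csr_format` (a StorageFormat built via the
# levels module) are placeholders assumed to exist.
converted = asformat(x, format=csr_format)  # routed through get_convert_module
assert converted.shape == x.shape
assert asformat(converted, format=csr_format) is converted  # same format: early return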
+def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
+    from ._conversions import _from_numpy
+
+    shape = as_shape(shape)
+    if math.prod(x.shape) != math.prod(shape):
+        raise ValueError(f"`math.prod(x.shape) != math.prod(shape)`, {x.shape=}, {shape=}")
+
+    ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
+    shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
+    ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
+
+    reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
+
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+        *shape_array._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=shape)
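# --- Editor's note (not part of the diff): a hedged sketch of reshape(). `x` is
# assumed to be a 2x3 Array built elsewhere (e.g. via _from_numpy above); the
# element count must be preserved or a ValueError is raised.
y = reshape(x, (3, 2))
assert y.shape == (3, 2)

try:
    reshape(x, (4, 2))  # 6 elements cannot be reshaped into 8
except ValueError as e:
    print(e)  # reports math.prod(x.shape) != math.prod(shape)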