1111from ._common import as_shape , fn_cache
1212from ._core import CWD , DEBUG , OPT_LEVEL , SHARED_LIBS , ctx , pm
1313from ._dtypes import DType , IeeeComplexFloatingDType , IeeeRealFloatingDType , IntegerDType
14- from .levels import _determine_format
14+ from .levels import StorageFormat , _determine_format
1515
1616
1717@fn_cache
@@ -135,7 +135,31 @@ def broadcast_to(in_tensor):
135135 return mlir_finch .execution_engine .ExecutionEngine (module , opt_level = OPT_LEVEL , shared_libs = SHARED_LIBS )
136136
137137
138- def add (x1 : Array , x2 : Array ) -> Array :
138+ @fn_cache
139+ def get_convert_module (
140+ in_tensor_type : ir .RankedTensorType ,
141+ out_tensor_type : ir .RankedTensorType ,
142+ ):
143+ with ir .Location .unknown (ctx ):
144+ module = ir .Module .create ()
145+
146+ with ir .InsertionPoint (module .body ):
147+
148+ @func .FuncOp .from_py_func (in_tensor_type )
149+ def convert (in_tensor ):
150+ return sparse_tensor .convert (out_tensor_type , in_tensor )
151+
152+ convert .func_op .attributes ["llvm.emit_c_interface" ] = ir .UnitAttr .get ()
153+ if DEBUG :
154+ (CWD / "broadcast_to_module.mlir" ).write_text (str (module ))
155+ pm .run (module .operation )
156+ if DEBUG :
157+ (CWD / "broadcast_to_module_opt.mlir" ).write_text (str (module ))
158+
159+ return mlir_finch .execution_engine .ExecutionEngine (module , opt_level = OPT_LEVEL , shared_libs = SHARED_LIBS )
160+
161+
162+ def add (x1 : Array , x2 : Array , / ) -> Array :
139163 ret_storage_format = _determine_format (x1 .format , x2 .format , dtype = x1 .dtype , union = True )
140164 ret_storage = ret_storage_format ._get_ctypes_type (owns_memory = True )()
141165 out_tensor_type = ret_storage_format ._get_mlir_type (shape = np .broadcast_shapes (x1 .shape , x2 .shape ))
@@ -156,6 +180,24 @@ def add(x1: Array, x2: Array) -> Array:
156180 return Array (storage = ret_storage , shape = tuple (out_tensor_type .shape ))
157181
158182
183+ def asformat (x : Array , / , format : StorageFormat ) -> Array :
184+ out_tensor_type = format ._get_mlir_type (shape = x .shape )
185+ ret_storage = format ._get_ctypes_type (owns_memory = True )()
186+
187+ convert_module = get_convert_module (
188+ x ._get_mlir_type (),
189+ out_tensor_type ,
190+ )
191+
192+ convert_module .invoke (
193+ "convert" ,
194+ ctypes .pointer (ctypes .pointer (ret_storage )),
195+ * x ._to_module_arg (),
196+ )
197+
198+ return Array (storage = ret_storage , shape = x .shape )
199+
200+
159201def reshape (x : Array , / , shape : tuple [int , ...]) -> Array :
160202 from ._conversions import _from_numpy
161203
0 commit comments