55from mlir import ir
66from mlir .dialects import arith , complex , func , linalg , sparse_tensor , tensor
77
8+ import numpy as np
9+
810from ._array import Array
911from ._common import fn_cache
1012from ._core import CWD , DEBUG , MLIR_C_RUNNER_UTILS , ctx , pm
1113from ._dtypes import DType , IeeeComplexFloatingDType , IeeeRealFloatingDType , IntegerDType
14+ from .levels import _determine_format
1215
1316
1417@fn_cache
@@ -17,7 +20,6 @@ def get_add_module(
1720 b_tensor_type : ir .RankedTensorType ,
1821 out_tensor_type : ir .RankedTensorType ,
1922 dtype : DType ,
20- rank : int ,
2123) -> ir .Module :
2224 with ir .Location .unknown (ctx ):
2325 module = ir .Module .create ()
@@ -31,7 +33,7 @@ def get_add_module(
3133 raise RuntimeError (f"Can not add { dtype = } ." )
3234
3335 dtype = dtype ._get_mlir_type ()
34- ordering = ir . AffineMap . get_permutation ( range ( rank ))
36+ max_rank = out_tensor_type . rank
3537
3638 with ir .InsertionPoint (module .body ):
3739
@@ -42,8 +44,13 @@ def add(a, b):
4244 [out_tensor_type ],
4345 [a , b ],
4446 [out ],
45- ir .ArrayAttr .get ([ir .AffineMapAttr .get (p ) for p in (ordering ,) * 3 ]),
46- ir .ArrayAttr .get ([ir .Attribute .parse ("#linalg.iterator_type<parallel>" )] * rank ),
47+ ir .ArrayAttr .get (
48+ [
49+ ir .AffineMapAttr .get (ir .AffineMap .get_minor_identity (max_rank , t .rank ))
50+ for t in (a_tensor_type , b_tensor_type , out_tensor_type )
51+ ]
52+ ),
53+ ir .ArrayAttr .get ([ir .Attribute .parse ("#linalg.iterator_type<parallel>" )] * out_tensor_type .rank ),
4754 )
4855 block = generic_op .regions [0 ].blocks .append (dtype , dtype , dtype )
4956 with ir .InsertionPoint (block ):
@@ -127,17 +134,16 @@ def broadcast_to(in_tensor):
127134
128135
129136def add (x1 : Array , x2 : Array ) -> Array :
130- ret_storage_format = x1 .format
137+ ret_storage_format = _determine_format ( x1 .format , x2 . format , dtype = x1 . dtype , union = True )
131138 ret_storage = ret_storage_format ._get_ctypes_type (owns_memory = True )()
132- out_tensor_type = ret_storage_format ._get_mlir_type (shape = x1 .shape )
139+ out_tensor_type = ret_storage_format ._get_mlir_type (shape = np . broadcast_shapes ( x1 .shape , x2 . shape ) )
133140
134141 # TODO: Decide what the output tensor_type should be
135142 add_module = get_add_module (
136143 x1 ._get_mlir_type (),
137144 x2 ._get_mlir_type (),
138145 out_tensor_type = out_tensor_type ,
139146 dtype = x1 .dtype ,
140- rank = x1 .ndim ,
141147 )
142148 add_module .invoke (
143149 "add" ,
0 commit comments