44# See https://llvm.org/LICENSE.txt for license information.
55# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
7+ from os import environ
78import sympy
89from typing import Any , Callable , ClassVar , Optional , List , Type , Dict
910from dataclasses import dataclass
2122
2223
2324from ...compiler .ir import (
25+ AffineExpr ,
26+ AffineMap ,
2427 Attribute ,
2528 DenseElementsAttr ,
2629 FloatAttr ,
3437 ShapedType ,
3538 Value ,
3639 VectorType ,
40+ affine_d ,
3741 arith_d ,
3842 func_d ,
3943 gpu_d ,
@@ -198,55 +202,77 @@ def add_emitter_subs(
198202 return dynamics
199203
200204
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean feature flag from the environment.

    Integer strings keep their historical meaning (`"0"` is false, any other
    integer is true). The words true/yes/on and false/no/off are accepted
    case-insensitively. An unset or blank variable yields `default`.

    Raises:
        ValueError: if the value is neither an integer nor a recognised word.
            The message names the offending variable so a typo does not
            surface as an anonymous `int()` failure at import time.
    """
    raw = environ.get(name)
    if raw is None:
        return default

    text = raw.strip()
    if not text:
        return default

    try:
        return bool(int(text))
    except ValueError:
        pass

    lowered = text.lower()
    if lowered in ("true", "yes", "on"):
        return True
    if lowered in ("false", "no", "off"):
        return False

    raise ValueError(
        f"Invalid boolean value {raw!r} for environment variable {name}"
    )


# When set, the non-affine ceildiv fallback is emitted as an explicit
# subtract/divide/add/select sequence instead of `arith.ceildivsi`.
_emulate_ceildiv = _env_flag("WAVE_EMULATE_CEILDIV", False)
# When set (the default), index-typed arithmetic is accumulated into affine
# expressions instead of being emitted as individual arith ops.
_use_affine_expr = _env_flag("WAVE_USE_AFFINE_EXPR", True)

# A fraction whose division has not been emitted yet.
_Rational = namedtuple("_Rational", ["numerator", "denominator"])
# A deferred affine expression together with the operands bound to its symbols.
_ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"])
204210
205211
206- def gen_sympy_index (dynamics : dict [IndexSymbol , Value ], expr : sympy .Expr ) -> OpResult :
212+ def gen_sympy_index (dynamics : dict [IndexSymbol , Value ], expr : sympy .Expr ) -> Value :
213+ use_affine_expr = _use_affine_expr
207214 stack : list [OpResult ] = []
208215
209216 def _get_ir_value (arg ):
217+ if isinstance (arg , _ApplyExpr ):
218+ args = _broadcast (* arg .args )
219+ expr = arg .expr
220+ expr = AffineMap .get (dim_count = 0 , symbol_count = len (args ), exprs = [expr ])
221+
222+ return affine_d .apply (expr , args )
223+
210224 if not isinstance (arg , (Value , OpResult )):
211225 arg = arg .result
212226
213227 return arg
214228
215229 def _check_vec_scalar (a , b ):
216- if not isinstance (a . type , VectorType ):
230+ if not isinstance (a , VectorType ):
217231 return False
218232
219- if a .type . element_type == b . type :
233+ if a .element_type == b :
220234 return True
221235
222236 return (
223- isinstance (b . type , VectorType )
224- and b .type . shape == [1 ]
225- and a .type . element_type == b . type .element_type
237+ isinstance (b , VectorType )
238+ and b .shape == [1 ]
239+ and a .element_type == b .element_type
226240 )
227241
228- def _broadcast (a , b ):
229- a = _get_ir_value (a )
230- b = _get_ir_value (b )
242+ def _broadcast (* args ):
243+ assert len (args ) > 0
244+ if len (args ) == 1 :
245+ return args
246+
247+ res_args = [_get_ir_value (a ) for a in args ]
248+ res_type = res_args [0 ].type
249+ for arg in res_args [1 :]:
250+ arg_type = arg .type
251+ if arg_type == res_type :
252+ continue
253+
254+ if _check_vec_scalar (res_type , arg_type ):
255+ # broadcast to res_type
256+ continue
257+
258+ if _check_vec_scalar (arg_type , res_type ):
259+ res_type = arg_type
260+ continue
231261
232- if a .type == b .type :
233- return a , b
262+ raise CodegenError (f"Cannot broadcast { res_type } and { arg .type } " )
234263
235- if _check_vec_scalar ( a , b ):
236- if isinstance ( b .type , VectorType ) :
237- b = vector_d . extract ( b , static_position = [ 0 ], dynamic_position = [])
264+ for i , arg in enumerate ( res_args ):
265+ if arg .type == res_type :
266+ continue
238267
239- b = vector_d . splat ( a .type , b )
240- return a , b
268+ if isinstance ( arg .type , VectorType ):
269+ arg = vector_d . extract ( arg , static_position = [ 0 ], dynamic_position = [])
241270
242- if _check_vec_scalar (b , a ):
243- if isinstance (a .type , VectorType ):
244- a = vector_d .extract (a , static_position = [0 ], dynamic_position = [])
271+ res_args [i ] = vector_d .splat (res_type , arg )
245272
246- a = vector_d .splat (b .type , a )
247- return a , b
273+ assert all (arg .type == res_type for arg in res_args )
248274
249- raise CodegenError ( f"Cannot broadcast { a . type } and { b . type } " )
275+ return tuple ( res_args )
250276
251277 def get_const_val (arg ):
252278 if isinstance (arg , OpResult ):
@@ -287,70 +313,125 @@ def addi(lhs, rhs):
287313
288314 return arith_d .addi (lhs , rhs , overflow_flags = overflow_flags )
289315
316+ def op_expr (lhs , rhs , op ):
317+ if isinstance (lhs , _ApplyExpr ):
318+ lhs_args = lhs .args
319+ lhs_expr = lhs .expr
320+ else :
321+ lhs_args = [lhs ]
322+ lhs_expr = AffineExpr .get_symbol (0 )
323+
324+ if isinstance (rhs , _ApplyExpr ):
325+ rhs_args = rhs .args
326+ rhs_expr = rhs .expr
327+ else :
328+ rhs_args = [rhs ]
329+ rhs_expr = AffineExpr .get_symbol (0 )
330+
331+ args = lhs_args + rhs_args
332+ expr = op (lhs_expr , rhs_expr .shift_symbols (len (rhs_args ), len (lhs_args )))
333+ return _ApplyExpr (expr , args )
334+
335+ def check_index_types (* args ):
336+ return all (
337+ isinstance (a , _ApplyExpr ) or isinstance (a .type , IndexType ) for a in args
338+ )
339+
340+ def add_expr (lhs , rhs ):
341+ if not use_affine_expr or not check_index_types (lhs , rhs ):
342+ return addi (* _broadcast (lhs , rhs ))
343+
344+ return op_expr (lhs , rhs , lambda a , b : a + b )
345+
346+ def muli_expr (lhs , rhs ):
347+ if not use_affine_expr or not check_index_types (lhs , rhs ):
348+ return muli (* _broadcast (lhs , rhs ))
349+
350+ return op_expr (lhs , rhs , lambda a , b : a * b )
351+
352+ def rem_expr (lhs , rhs ):
353+ if not use_affine_expr or not check_index_types (lhs , rhs ):
354+ return arith_d .remsi (* _broadcast (lhs , rhs ))
355+
356+ return op_expr (lhs , rhs , lambda a , b : a % b )
357+
358+ def floordiv_expr (lhs , rhs ):
359+ if not use_affine_expr or not check_index_types (lhs , rhs ):
360+ return arith_d .divsi (* _broadcast (lhs , rhs ))
361+
362+ return op_expr (lhs , rhs , lambda a , b : AffineExpr .get_floor_div (a , b ))
363+
364+ def ceildiv_expr (lhs , rhs ):
365+ if not use_affine_expr or not check_index_types (lhs , rhs ):
366+ if _emulate_ceildiv :
367+ # ceildivui(x, y) = x == 0 ? 0 : ((x - 1) / y) + 1
368+ one = _get_const (1 )
369+ zero = _get_const (0 )
370+ lhs_minus_one = arith_d .subi (* _broadcast (lhs , one ))
371+ div = arith_d .divui (* _broadcast (lhs_minus_one , rhs ))
372+ result = arith_d .addi (* _broadcast (div , one ))
373+ cmp = arith_d .cmpi (arith_d .CmpIPredicate .eq , * _broadcast (lhs , zero ))
374+ zero , result = _broadcast (zero , result )
375+ return arith_d .select (cmp , zero , result )
376+ else :
377+ return arith_d .ceildivsi (* _broadcast (lhs , rhs ))
378+
379+ return op_expr (lhs , rhs , lambda a , b : AffineExpr .get_ceil_div (a , b ))
380+
290381 # `x + (a/b)` transformed into `(x*b + a) / b`
291382 def _add (lhs , rhs ):
292383 is_rational_lhs = isinstance (lhs , _Rational )
293384 is_rational_rhs = isinstance (rhs , _Rational )
294385 if is_rational_lhs and not is_rational_rhs :
295- numerator = muli ( * _broadcast ( lhs .denominator , rhs ) )
296- numerator = addi ( * _broadcast ( numerator , lhs .numerator ) )
386+ numerator = muli_expr ( lhs .denominator , rhs )
387+ numerator = add_expr ( numerator , lhs .numerator )
297388 return _Rational (numerator , lhs .denominator )
298389 elif not is_rational_lhs and is_rational_rhs :
299- numerator = muli ( * _broadcast ( lhs , rhs .denominator ) )
300- numerator = addi ( * _broadcast ( numerator , rhs .numerator ) )
390+ numerator = muli_expr ( lhs , rhs .denominator )
391+ numerator = add_expr ( numerator , rhs .numerator )
301392 return _Rational (numerator , rhs .denominator )
302393 elif is_rational_lhs and is_rational_rhs :
303- lhs_numerator = muli ( * _broadcast ( lhs .numerator , rhs .denominator ) )
304- rhs_numerator = muli ( * _broadcast ( rhs .numerator , lhs .denominator ) )
305- numerator = addi ( * _broadcast ( lhs_numerator , rhs_numerator ) )
306- denominator = muli ( * _broadcast ( lhs .denominator , rhs .denominator ) )
394+ lhs_numerator = muli_expr ( lhs .numerator , rhs .denominator )
395+ rhs_numerator = muli_expr ( rhs .numerator , lhs .denominator )
396+ numerator = add_expr ( lhs_numerator , rhs_numerator )
397+ denominator = muli_expr ( lhs .denominator , rhs .denominator )
307398 return _Rational (numerator , denominator )
308399 else :
309- return addi ( * _broadcast ( lhs , rhs ) )
400+ return add_expr ( lhs , rhs )
310401
311402 # `x * (a/b)` transformed into `(x * a) / b`
312403 def _mul (lhs , rhs ):
313404 is_rational_lhs = isinstance (lhs , _Rational )
314405 is_rational_rhs = isinstance (rhs , _Rational )
315406 if is_rational_lhs and not is_rational_rhs :
316- numerator = muli ( * _broadcast ( lhs .numerator , rhs ) )
407+ numerator = muli_expr ( lhs .numerator , rhs )
317408 return _Rational (numerator , lhs .denominator )
318409 elif not is_rational_lhs and is_rational_rhs :
319- numerator = muli ( * _broadcast ( lhs , rhs .numerator ) )
410+ numerator = muli_expr ( lhs , rhs .numerator )
320411 return _Rational (numerator , rhs .denominator )
321412 elif is_rational_lhs and is_rational_rhs :
322- numerator = muli ( * _broadcast ( lhs .numerator , rhs .numerator ) )
323- denominator = muli ( * _broadcast ( lhs .denominator , rhs .denominator ) )
413+ numerator = muli_expr ( lhs .numerator , rhs .numerator )
414+ denominator = muli_expr ( lhs .denominator , rhs .denominator )
324415 return _Rational (numerator , denominator )
325416 else :
326- return muli (* _broadcast (lhs , rhs ))
417+ return muli_expr (lhs , rhs )
418+
419+ def _rem (lhs , rhs ):
420+ assert not isinstance (lhs , _Rational ) and not isinstance (rhs , _Rational )
421+
422+ return rem_expr (lhs , rhs )
327423
328424 def _floor (value ):
329- if isinstance (value , _Rational ):
330- value = arith_d . divsi ( * _broadcast ( value . numerator , value . denominator ))
425+ if not isinstance (value , _Rational ):
426+ return value
331427
332- return value
428+ return floordiv_expr ( value . numerator , value . denominator )
333429
334430 def _ceiling (value ):
335- if isinstance (value , _Rational ):
336- if _emulate_ceildiv :
337- # ceildivui(x, y) = x == 0 ? 0 : ((x - 1) / y) + 1
338- one = _get_const (1 )
339- zero = _get_const (0 )
340- lhs_minus_one = arith_d .subi (* _broadcast (value .numerator , one ))
341- div = arith_d .divui (* _broadcast (lhs_minus_one , value .denominator ))
342- result = arith_d .addi (* _broadcast (div , one ))
343- cmp = arith_d .cmpi (
344- arith_d .CmpIPredicate .eq , * _broadcast (value .numerator , zero )
345- )
346- zero , result = _broadcast (zero , result )
347- value = arith_d .select (cmp , zero , result )
348- else :
349- value = arith_d .ceildivsi (
350- * _broadcast (value .numerator , value .denominator )
351- )
431+ if not isinstance (value , _Rational ):
432+ return value
352433
353- return value
434+ return ceildiv_expr ( value . numerator , value . denominator )
354435
355436 def _group_rationals (stack , count ):
356437 """Group rationals and non-rationals args into 2 contiguous sets.
@@ -441,8 +522,7 @@ def _get_const(val):
441522 lhs = stack .pop ()
442523 _enforce_non_rational (rhs , term )
443524 _enforce_non_rational (lhs , term )
444- mod = arith_d .remsi (* _broadcast (lhs , rhs ))
445- stack .append (mod )
525+ stack .append (_rem (lhs , rhs ))
446526 case sympy .floor ():
447527 stack .append (_floor (stack .pop ()))
448528 case sympy .ceiling ():
@@ -495,6 +575,8 @@ def _get_const(val):
495575 lhs = stack .pop ()
496576 _enforce_non_rational (rhs , term )
497577 _enforce_non_rational (lhs , term )
578+ rhs = _get_ir_value (rhs )
579+ lhs = _get_ir_value (lhs )
498580 elem_type = get_type_or_element_type (rhs .type )
499581 if _is_integer_like_type (elem_type ):
500582 res = arith_d .maxsi (* _broadcast (lhs , rhs ))
@@ -506,6 +588,8 @@ def _get_const(val):
506588 lhs = stack .pop ()
507589 _enforce_non_rational (rhs , term )
508590 _enforce_non_rational (lhs , term )
591+ rhs = _get_ir_value (rhs )
592+ lhs = _get_ir_value (lhs )
509593 elem_type = get_type_or_element_type (rhs .type )
510594 if _is_integer_like_type (elem_type ):
511595 res = arith_d .minsi (* _broadcast (lhs , rhs ))
@@ -555,7 +639,7 @@ def _get_const(val):
555639 if len (stack ) != 1 or isinstance (stack [0 ], _Rational ):
556640 raise CodegenError (f"Expected single result, got { len (stack )} " )
557641
558- return stack [0 ]
642+ return _get_ir_value ( stack [0 ])
559643
560644
561645def get_constant_attr (value : Any , element_type : IrType ) -> Attribute :
0 commit comments