Skip to content

Commit abeb3a0

Browse files
Hardcode84nithinsubbiah
authored andcommitted
[TKW] Use affine apply (#666)
* Add ability to use `affine.apply` for index codegen * Limited to non-vector index type for now, need llvm/llvm-project#129442 for vector support --------- Signed-off-by: Ivan Butygin <[email protected]> Signed-off-by: nithinsubbiah <[email protected]>
1 parent d77d213 commit abeb3a0

File tree

8 files changed

+531
-750
lines changed

8 files changed

+531
-750
lines changed

iree/turbine/kernel/compiler/ir.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
)
3737

3838
from iree.compiler.dialects import (
39-
arith as arith_d,
39+
affine as affine_d,
4040
amdgpu as amdgpu_d,
41+
arith as arith_d,
4142
builtin as builtin_d,
4243
flow as flow_d,
4344
func as func_d,
@@ -47,8 +48,8 @@
4748
math as math_d,
4849
memref as memref_d,
4950
rocdl as rocdl_d,
50-
stream as stream_d,
5151
scf as scf_d,
52+
stream as stream_d,
5253
transform as transform_d,
5354
vector as vector_d,
5455
)

iree/turbine/kernel/wave/codegen/emitter.py

Lines changed: 145 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
from os import environ
78
import sympy
89
from typing import Any, Callable, ClassVar, Optional, List, Type, Dict
910
from dataclasses import dataclass
@@ -21,6 +22,8 @@
2122

2223

2324
from ...compiler.ir import (
25+
AffineExpr,
26+
AffineMap,
2427
Attribute,
2528
DenseElementsAttr,
2629
FloatAttr,
@@ -34,6 +37,7 @@
3437
ShapedType,
3538
Value,
3639
VectorType,
40+
affine_d,
3741
arith_d,
3842
func_d,
3943
gpu_d,
@@ -198,55 +202,77 @@ def add_emitter_subs(
198202
return dynamics
199203

200204

201-
_emulate_ceildiv = False
205+
_emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0)))
206+
_use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1)))
202207

203208
_Rational = namedtuple("_Rational", ["numerator", "denominator"])
209+
_ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"])
204210

205211

206-
def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> OpResult:
212+
def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> Value:
213+
use_affine_expr = _use_affine_expr
207214
stack: list[OpResult] = []
208215

209216
def _get_ir_value(arg):
217+
if isinstance(arg, _ApplyExpr):
218+
args = _broadcast(*arg.args)
219+
expr = arg.expr
220+
expr = AffineMap.get(dim_count=0, symbol_count=len(args), exprs=[expr])
221+
222+
return affine_d.apply(expr, args)
223+
210224
if not isinstance(arg, (Value, OpResult)):
211225
arg = arg.result
212226

213227
return arg
214228

215229
def _check_vec_scalar(a, b):
216-
if not isinstance(a.type, VectorType):
230+
if not isinstance(a, VectorType):
217231
return False
218232

219-
if a.type.element_type == b.type:
233+
if a.element_type == b:
220234
return True
221235

222236
return (
223-
isinstance(b.type, VectorType)
224-
and b.type.shape == [1]
225-
and a.type.element_type == b.type.element_type
237+
isinstance(b, VectorType)
238+
and b.shape == [1]
239+
and a.element_type == b.element_type
226240
)
227241

228-
def _broadcast(a, b):
229-
a = _get_ir_value(a)
230-
b = _get_ir_value(b)
242+
def _broadcast(*args):
243+
assert len(args) > 0
244+
if len(args) == 1:
245+
return args
246+
247+
res_args = [_get_ir_value(a) for a in args]
248+
res_type = res_args[0].type
249+
for arg in res_args[1:]:
250+
arg_type = arg.type
251+
if arg_type == res_type:
252+
continue
253+
254+
if _check_vec_scalar(res_type, arg_type):
255+
# broadcast to res_type
256+
continue
257+
258+
if _check_vec_scalar(arg_type, res_type):
259+
res_type = arg_type
260+
continue
231261

232-
if a.type == b.type:
233-
return a, b
262+
raise CodegenError(f"Cannot broadcast {res_type} and {arg.type}")
234263

235-
if _check_vec_scalar(a, b):
236-
if isinstance(b.type, VectorType):
237-
b = vector_d.extract(b, static_position=[0], dynamic_position=[])
264+
for i, arg in enumerate(res_args):
265+
if arg.type == res_type:
266+
continue
238267

239-
b = vector_d.splat(a.type, b)
240-
return a, b
268+
if isinstance(arg.type, VectorType):
269+
arg = vector_d.extract(arg, static_position=[0], dynamic_position=[])
241270

242-
if _check_vec_scalar(b, a):
243-
if isinstance(a.type, VectorType):
244-
a = vector_d.extract(a, static_position=[0], dynamic_position=[])
271+
res_args[i] = vector_d.splat(res_type, arg)
245272

246-
a = vector_d.splat(b.type, a)
247-
return a, b
273+
assert all(arg.type == res_type for arg in res_args)
248274

249-
raise CodegenError(f"Cannot broadcast {a.type} and {b.type}")
275+
return tuple(res_args)
250276

251277
def get_const_val(arg):
252278
if isinstance(arg, OpResult):
@@ -287,70 +313,125 @@ def addi(lhs, rhs):
287313

288314
return arith_d.addi(lhs, rhs, overflow_flags=overflow_flags)
289315

316+
def op_expr(lhs, rhs, op):
317+
if isinstance(lhs, _ApplyExpr):
318+
lhs_args = lhs.args
319+
lhs_expr = lhs.expr
320+
else:
321+
lhs_args = [lhs]
322+
lhs_expr = AffineExpr.get_symbol(0)
323+
324+
if isinstance(rhs, _ApplyExpr):
325+
rhs_args = rhs.args
326+
rhs_expr = rhs.expr
327+
else:
328+
rhs_args = [rhs]
329+
rhs_expr = AffineExpr.get_symbol(0)
330+
331+
args = lhs_args + rhs_args
332+
expr = op(lhs_expr, rhs_expr.shift_symbols(len(rhs_args), len(lhs_args)))
333+
return _ApplyExpr(expr, args)
334+
335+
def check_index_types(*args):
336+
return all(
337+
isinstance(a, _ApplyExpr) or isinstance(a.type, IndexType) for a in args
338+
)
339+
340+
def add_expr(lhs, rhs):
341+
if not use_affine_expr or not check_index_types(lhs, rhs):
342+
return addi(*_broadcast(lhs, rhs))
343+
344+
return op_expr(lhs, rhs, lambda a, b: a + b)
345+
346+
def muli_expr(lhs, rhs):
347+
if not use_affine_expr or not check_index_types(lhs, rhs):
348+
return muli(*_broadcast(lhs, rhs))
349+
350+
return op_expr(lhs, rhs, lambda a, b: a * b)
351+
352+
def rem_expr(lhs, rhs):
353+
if not use_affine_expr or not check_index_types(lhs, rhs):
354+
return arith_d.remsi(*_broadcast(lhs, rhs))
355+
356+
return op_expr(lhs, rhs, lambda a, b: a % b)
357+
358+
def floordiv_expr(lhs, rhs):
359+
if not use_affine_expr or not check_index_types(lhs, rhs):
360+
return arith_d.divsi(*_broadcast(lhs, rhs))
361+
362+
return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_floor_div(a, b))
363+
364+
def ceildiv_expr(lhs, rhs):
365+
if not use_affine_expr or not check_index_types(lhs, rhs):
366+
if _emulate_ceildiv:
367+
# ceildivui(x, y) = x == 0 ? 0 : ((x - 1) / y) + 1
368+
one = _get_const(1)
369+
zero = _get_const(0)
370+
lhs_minus_one = arith_d.subi(*_broadcast(lhs, one))
371+
div = arith_d.divui(*_broadcast(lhs_minus_one, rhs))
372+
result = arith_d.addi(*_broadcast(div, one))
373+
cmp = arith_d.cmpi(arith_d.CmpIPredicate.eq, *_broadcast(lhs, zero))
374+
zero, result = _broadcast(zero, result)
375+
return arith_d.select(cmp, zero, result)
376+
else:
377+
return arith_d.ceildivsi(*_broadcast(lhs, rhs))
378+
379+
return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_ceil_div(a, b))
380+
290381
# `x + (a/b)` transformed into `(x*b + a) / b`
291382
def _add(lhs, rhs):
292383
is_rational_lhs = isinstance(lhs, _Rational)
293384
is_rational_rhs = isinstance(rhs, _Rational)
294385
if is_rational_lhs and not is_rational_rhs:
295-
numerator = muli(*_broadcast(lhs.denominator, rhs))
296-
numerator = addi(*_broadcast(numerator, lhs.numerator))
386+
numerator = muli_expr(lhs.denominator, rhs)
387+
numerator = add_expr(numerator, lhs.numerator)
297388
return _Rational(numerator, lhs.denominator)
298389
elif not is_rational_lhs and is_rational_rhs:
299-
numerator = muli(*_broadcast(lhs, rhs.denominator))
300-
numerator = addi(*_broadcast(numerator, rhs.numerator))
390+
numerator = muli_expr(lhs, rhs.denominator)
391+
numerator = add_expr(numerator, rhs.numerator)
301392
return _Rational(numerator, rhs.denominator)
302393
elif is_rational_lhs and is_rational_rhs:
303-
lhs_numerator = muli(*_broadcast(lhs.numerator, rhs.denominator))
304-
rhs_numerator = muli(*_broadcast(rhs.numerator, lhs.denominator))
305-
numerator = addi(*_broadcast(lhs_numerator, rhs_numerator))
306-
denominator = muli(*_broadcast(lhs.denominator, rhs.denominator))
394+
lhs_numerator = muli_expr(lhs.numerator, rhs.denominator)
395+
rhs_numerator = muli_expr(rhs.numerator, lhs.denominator)
396+
numerator = add_expr(lhs_numerator, rhs_numerator)
397+
denominator = muli_expr(lhs.denominator, rhs.denominator)
307398
return _Rational(numerator, denominator)
308399
else:
309-
return addi(*_broadcast(lhs, rhs))
400+
return add_expr(lhs, rhs)
310401

311402
# `x * (a/b)` transformed into `(x * a) / b`
312403
def _mul(lhs, rhs):
313404
is_rational_lhs = isinstance(lhs, _Rational)
314405
is_rational_rhs = isinstance(rhs, _Rational)
315406
if is_rational_lhs and not is_rational_rhs:
316-
numerator = muli(*_broadcast(lhs.numerator, rhs))
407+
numerator = muli_expr(lhs.numerator, rhs)
317408
return _Rational(numerator, lhs.denominator)
318409
elif not is_rational_lhs and is_rational_rhs:
319-
numerator = muli(*_broadcast(lhs, rhs.numerator))
410+
numerator = muli_expr(lhs, rhs.numerator)
320411
return _Rational(numerator, rhs.denominator)
321412
elif is_rational_lhs and is_rational_rhs:
322-
numerator = muli(*_broadcast(lhs.numerator, rhs.numerator))
323-
denominator = muli(*_broadcast(lhs.denominator, rhs.denominator))
413+
numerator = muli_expr(lhs.numerator, rhs.numerator)
414+
denominator = muli_expr(lhs.denominator, rhs.denominator)
324415
return _Rational(numerator, denominator)
325416
else:
326-
return muli(*_broadcast(lhs, rhs))
417+
return muli_expr(lhs, rhs)
418+
419+
def _rem(lhs, rhs):
420+
assert not isinstance(lhs, _Rational) and not isinstance(rhs, _Rational)
421+
422+
return rem_expr(lhs, rhs)
327423

328424
def _floor(value):
329-
if isinstance(value, _Rational):
330-
value = arith_d.divsi(*_broadcast(value.numerator, value.denominator))
425+
if not isinstance(value, _Rational):
426+
return value
331427

332-
return value
428+
return floordiv_expr(value.numerator, value.denominator)
333429

334430
def _ceiling(value):
335-
if isinstance(value, _Rational):
336-
if _emulate_ceildiv:
337-
# ceildivui(x, y) = x == 0 ? 0 : ((x - 1) / y) + 1
338-
one = _get_const(1)
339-
zero = _get_const(0)
340-
lhs_minus_one = arith_d.subi(*_broadcast(value.numerator, one))
341-
div = arith_d.divui(*_broadcast(lhs_minus_one, value.denominator))
342-
result = arith_d.addi(*_broadcast(div, one))
343-
cmp = arith_d.cmpi(
344-
arith_d.CmpIPredicate.eq, *_broadcast(value.numerator, zero)
345-
)
346-
zero, result = _broadcast(zero, result)
347-
value = arith_d.select(cmp, zero, result)
348-
else:
349-
value = arith_d.ceildivsi(
350-
*_broadcast(value.numerator, value.denominator)
351-
)
431+
if not isinstance(value, _Rational):
432+
return value
352433

353-
return value
434+
return ceildiv_expr(value.numerator, value.denominator)
354435

355436
def _group_rationals(stack, count):
356437
"""Group rationals and non-rationals args into 2 contiguous sets.
@@ -441,8 +522,7 @@ def _get_const(val):
441522
lhs = stack.pop()
442523
_enforce_non_rational(rhs, term)
443524
_enforce_non_rational(lhs, term)
444-
mod = arith_d.remsi(*_broadcast(lhs, rhs))
445-
stack.append(mod)
525+
stack.append(_rem(lhs, rhs))
446526
case sympy.floor():
447527
stack.append(_floor(stack.pop()))
448528
case sympy.ceiling():
@@ -495,6 +575,8 @@ def _get_const(val):
495575
lhs = stack.pop()
496576
_enforce_non_rational(rhs, term)
497577
_enforce_non_rational(lhs, term)
578+
rhs = _get_ir_value(rhs)
579+
lhs = _get_ir_value(lhs)
498580
elem_type = get_type_or_element_type(rhs.type)
499581
if _is_integer_like_type(elem_type):
500582
res = arith_d.maxsi(*_broadcast(lhs, rhs))
@@ -506,6 +588,8 @@ def _get_const(val):
506588
lhs = stack.pop()
507589
_enforce_non_rational(rhs, term)
508590
_enforce_non_rational(lhs, term)
591+
rhs = _get_ir_value(rhs)
592+
lhs = _get_ir_value(lhs)
509593
elem_type = get_type_or_element_type(rhs.type)
510594
if _is_integer_like_type(elem_type):
511595
res = arith_d.minsi(*_broadcast(lhs, rhs))
@@ -555,7 +639,7 @@ def _get_const(val):
555639
if len(stack) != 1 or isinstance(stack[0], _Rational):
556640
raise CodegenError(f"Expected single result, got {len(stack)}")
557641

558-
return stack[0]
642+
return _get_ir_value(stack[0])
559643

560644

561645
def get_constant_attr(value: Any, element_type: IrType) -> Attribute:

lit_tests/kernel/wave/attention/decode_attention.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,16 @@ def test_paged_flash_decoding():
5757
phase_0 = wave_compile(options, phase_0)
5858
print(phase_0.asm)
5959

60-
# CHECK-LABEL: func.func @phase_0
60+
# CHECK-LABEL: test_paged_flash_decoding
61+
# CHECK-DAG: #[[map:.*]] = affine_map<()[s0] -> (s0 ceildiv 16)>
62+
# CHECK: func.func @phase_0
6163
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
6264
# CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
63-
# CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
6465
# CHECK-COUNT-2: vector.load
6566
# CHECK: %[[COUNT0:.*]] = arith.minsi %{{.*}}, %{{.*}} : vector<1xindex>
6667
# CHECK: %[[COUNT1:.*]] = vector.extract %[[COUNT0]][0] : index from vector<1xindex>
6768
# CHECK-COUNT-2: vector.load
68-
# CHECK: %[[COUNT2:.*]] = arith.ceildivsi %[[COUNT1]], %[[C16]] : index
69+
# CHECK: %[[COUNT2:.*]] = affine.apply #[[map]]()[%[[COUNT1]]]
6970
# CHECK: scf.for %{{.*}} = %[[C0]] to %[[COUNT2]] step %[[C1]]
7071
# CHECK: amdgpu.lds_barrier
7172
# 1 masked load block table, 1 for k_cache, and 1 for v_cache.

0 commit comments

Comments
 (0)