Skip to content

Commit e25b3ca

Browse files
xintinraikonenfnu
authored andcommitted
Custom mask bshd attention variant (#665)
Added support for custom mask bshd kernel Added `operator.eq` and `bitwise or` Added implementation for `scaled_dot_product_attention` to verify the numeric. --------- Signed-off-by: Stanley Winata <[email protected]> Signed-off-by: xintin <[email protected]> Co-authored-by: Stanley Winata <[email protected]> Signed-off-by: nithinsubbiah <[email protected]>
1 parent c09eff4 commit e25b3ca

File tree

8 files changed

+538
-17
lines changed

8 files changed

+538
-17
lines changed

iree/turbine/aot/support/ir_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,10 @@ def _is_integer_like_type(type):
497497
return isinstance(type, (IntegerType, IndexType))
498498

499499

500+
def _is_signed_or_signless_type(type):
501+
return getattr(type, "is_signed", False) or getattr(type, "is_signless", False)
502+
503+
500504
def _attribute_from_device_affinity(
501505
affinity: DeviceAffinity, context: Context
502506
) -> Attribute:

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def le(lhs: "Register", rhs: "Register") -> "Register":
202202
...
203203

204204

205+
def eq(lhs: "Register", rhs: "Register") -> "Register":
206+
...
207+
208+
205209
def cast(src: "Register", dtype: DataType) -> "Register":
206210
...
207211

@@ -780,6 +784,7 @@ def infer_shape(self) -> Any:
780784
@define_py_op(operator.sub)
781785
@define_py_op(operator.mul)
782786
@define_py_op(operator.and_)
787+
@define_py_op(operator.or_)
783788
@define_py_op(operator.truediv)
784789
@define_interface_op("maximum")
785790
@define_interface_op("minimum")
@@ -789,10 +794,12 @@ def infer_type(self):
789794
self.type = Register[(*self.infer_shape(), get_custom(self.lhs).type.dtype)]
790795

791796

797+
@define_py_op(operator.eq)
792798
@define_py_op(operator.gt)
793799
@define_py_op(operator.ge)
794800
@define_py_op(operator.lt)
795801
@define_py_op(operator.le)
802+
@define_interface_op("eq")
796803
@define_interface_op("gt")
797804
@define_interface_op("ge")
798805
@define_interface_op("lt")

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
_is_float_type,
4848
_is_index_type,
4949
_is_integer_like_type,
50+
_is_signed_or_signless_type,
5051
)
5152

5253
# TK infrastructure imports.
@@ -58,6 +59,7 @@
5859
broadcast,
5960
cast,
6061
conditional,
62+
eq,
6163
exp2,
6264
extract,
6365
extract_slice,
@@ -444,8 +446,8 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
444446
element_type = get_type_or_element_type(lhs.type)
445447
if _is_float_type(element_type):
446448
result = arith_d.divf(lhs, rhs)
447-
elif _is_integer_like_type(element_type) and (
448-
element_type.is_signed or element_type.is_signless
449+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
450+
element_type
449451
):
450452
result = arith_d.divsi(lhs, rhs)
451453
else:
@@ -456,12 +458,28 @@ def handle_div(lhs: Value, rhs: Value) -> OpResult:
456458
@handle_binary_op(operator.and_)
457459
def handle_and(lhs: Value, rhs: Value) -> OpResult:
458460
element_type = get_type_or_element_type(lhs.type)
459-
if _is_integer_like_type(element_type) and (
460-
element_type.is_signed or element_type.is_signless
461+
if _is_integer_like_type(element_type) and _is_signed_or_signless_type(
462+
element_type
461463
):
462464
result = arith_d.andi(lhs, rhs)
463465
else:
464-
raise ValidationError(f"Found unhandled operand type for div: {element_type}")
466+
raise ValidationError(
467+
f"Found unhandled operand type for bitwise and: {element_type}"
468+
)
469+
return result
470+
471+
472+
@handle_binary_op(operator.or_)
473+
def handle_or(lhs: Value, rhs: Value) -> OpResult:
474+
element_type = get_type_or_element_type(lhs.type)
475+
if _is_integer_like_type(element_type) and _is_signed_or_signless_type(
476+
element_type
477+
):
478+
result = arith_d.ori(lhs, rhs)
479+
else:
480+
raise ValidationError(
481+
f"Found unhandled operand type for bitwise or: {element_type}"
482+
)
465483
return result
466484

467485

@@ -470,8 +488,8 @@ def handle_gt(lhs: Value, rhs: Value) -> OpResult:
470488
element_type = get_type_or_element_type(lhs.type)
471489
if _is_float_type(element_type):
472490
result = arith_d.cmpi(arith_d.CmpFPredicate.OGT, lhs, rhs)
473-
elif _is_integer_like_type(element_type) and (
474-
element_type.is_signed or element_type.is_signless
491+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
492+
element_type
475493
):
476494
result = arith_d.cmpi(arith_d.CmpIPredicate.sgt, lhs, rhs)
477495
else:
@@ -484,8 +502,8 @@ def handle_ge(lhs: Value, rhs: Value) -> OpResult:
484502
element_type = get_type_or_element_type(lhs.type)
485503
if _is_float_type(element_type):
486504
result = arith_d.cmpi(arith_d.CmpFPredicate.OGE, lhs, rhs)
487-
elif _is_integer_like_type(element_type) and (
488-
element_type.is_signed or element_type.is_signless
505+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
506+
element_type
489507
):
490508
result = arith_d.cmpi(arith_d.CmpIPredicate.sge, lhs, rhs)
491509
else:
@@ -498,8 +516,8 @@ def handle_lt(lhs: Value, rhs: Value) -> OpResult:
498516
element_type = get_type_or_element_type(lhs.type)
499517
if _is_float_type(element_type):
500518
result = arith_d.cmpi(arith_d.CmpFPredicate.OLT, lhs, rhs)
501-
elif _is_integer_like_type(element_type) and (
502-
element_type.is_signed or element_type.is_signless
519+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
520+
element_type
503521
):
504522
result = arith_d.cmpi(arith_d.CmpIPredicate.slt, lhs, rhs)
505523
else:
@@ -512,22 +530,36 @@ def handle_le(lhs: Value, rhs: Value) -> OpResult:
512530
element_type = get_type_or_element_type(lhs.type)
513531
if _is_float_type(element_type):
514532
result = arith_d.cmpi(arith_d.CmpFPredicate.OLE, lhs, rhs)
515-
elif _is_integer_like_type(element_type) and (
516-
element_type.is_signed or element_type.is_signless
533+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
534+
element_type
517535
):
518536
result = arith_d.cmpi(arith_d.CmpIPredicate.sle, lhs, rhs)
519537
else:
520538
raise ValidationError(f"Found unhandled operand type for le: {element_type}")
521539
return result
522540

523541

542+
@handle_binary_op([operator.eq, eq])
543+
def handle_le(lhs: Value, rhs: Value) -> OpResult:
544+
element_type = get_type_or_element_type(lhs.type)
545+
if _is_float_type(element_type):
546+
result = arith_d.cmpf(arith_d.CmpFPredicate.OEQ, lhs, rhs)
547+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
548+
element_type
549+
):
550+
result = arith_d.cmpi(arith_d.CmpIPredicate.eq, lhs, rhs)
551+
else:
552+
raise ValidationError(f"Found unhandled operand type for eq: {element_type}")
553+
return result
554+
555+
524556
@handle_binary_op(maximum)
525557
def handle_maximum(lhs: Value, rhs: Value) -> OpResult:
526558
element_type = get_type_or_element_type(lhs.type)
527559
if _is_float_type(element_type):
528560
result = arith_d.maximumf(lhs, rhs)
529-
elif _is_integer_like_type(element_type) and (
530-
element_type.is_signed or element_type.is_signless
561+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
562+
element_type
531563
):
532564
result = arith_d.maxsi(lhs, rhs)
533565
else:
@@ -542,8 +574,8 @@ def handle_minimum(lhs: Value, rhs: Value) -> OpResult:
542574
element_type = get_type_or_element_type(lhs.type)
543575
if _is_float_type(element_type):
544576
result = arith_d.minimumf(lhs, rhs)
545-
elif _is_integer_like_type(element_type) and (
546-
element_type.is_signed or element_type.is_signless
577+
elif _is_integer_like_type(element_type) and _is_signed_or_signless_type(
578+
element_type
547579
):
548580
result = arith_d.minsi(lhs, rhs)
549581
else:

0 commit comments

Comments
 (0)