
Commit 87c0c8c

Fix attention + enable VMMA (#43)
Update attention to work with ToM IREE and support VMMA:
1. Update translation info to use "pipeline".
2. Update attention IREE IR to specify the QK and PV MMA schedules separately such that it works with ToM IREE.
3. Refactor to use enum.Enum to represent intrinsics.
4. Add VMMA support and helper functions to maximize performance.

---------

Signed-off-by: Stanley Winata <[email protected]>
Co-authored-by: saienduri <[email protected]>
1 parent f0bd8a1 commit 87c0c8c
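In short, callers now pass an IntrinsicType enum member instead of a raw intrinsic string, and the QK and PV matmuls each get their own lowering config. A minimal usage sketch, mirroring the dummy test at the bottom of this diff (the import path is an assumption based on the file name):

    from attention_utils import AttentionConfig, IntrinsicType, TuningSpec, generate_mlir

    config = AttentionConfig(20, 4096, 64, 64, 4096, "f16")
    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1,
                      IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
    # Emits attention IR whose decomposition_config carries separate QK/PV MMA configs.
    print(generate_mlir(config, spec))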

File tree

1 file changed: +87, -18 lines

attentionbench/attention_utils.py

Lines changed: 87 additions & 18 deletions
@@ -2,8 +2,67 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
+from enum import Enum
 
 
+class IntrinsicType(Enum):
+    """
+    Formatting for different target intrinsics:
+        <kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
+
+    Values: 0xABCD where:
+    * A = vendor:
+      * 1 = AMD
+      * 2 = NVIDIA
+    * B = architecture. When an intrinsic exists in multiple architectures, this
+      should be the architecture it was introduced in, as long as it still
+      has the same semantics. If a new architecture breaks an existing
+      intrinsic's semantics, we can use that field for versioning.
+      * For AMD:
+        * 0 = CDNA1
+        * 1 = CDNA2
+        * 2 = CDNA3
+        * 8 = RDNA3
+    * C = element type of A-matrix:
+      * 0 = 64-bit float (e.g. IEEE754 double precision)
+      * 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
+      * 2 = 16-bit float (incl. IEEE754 half and bf16)
+      * 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
+      * C = 8-bit integer (any signedness)
+    * D enumerates intrinsics that share the same 0xABC* bits.
+    """
+    # Intrinsics introduced in CDNA1
+    MFMA_F32_16x16x16_F16 = 0x1020
+    MFMA_F32_32x32x8_F16 = 0x1021
+    VMFMA_F32_32x32x16_F16 = 0x1022
+    MFMA_I32_16x16x16_I8 = 0x10C0
+    MFMA_I32_32x32x8_I8 = 0x10C1
+
+    # Intrinsics introduced in CDNA3
+    MFMA_F32_16x16x32_F8 = 0x1230
+    MFMA_F32_32x32x16_F8 = 0x1231
+    MFMA_I32_16x16x32_I8 = 0x12C0
+    MFMA_I32_32x32x16_I8 = 0x12C1
+
+
+def get_intrinsic_string(intrinsic: IntrinsicType):
+    match intrinsic:
+        case IntrinsicType.VMFMA_F32_32x32x16_F16:
+            return f"#iree_gpu.virtual_mma_layout<intrinsic = {intrinsic.name}>"
+        case _:
+            return f"#iree_gpu.mma_layout<{intrinsic.name}>"
+
+def get_pv_intrinsic(intrinsic: IntrinsicType):
+    """
+    QK intrinsics and PV intrinsics can differ. Mostly used for
+    selecting VMFMA for QK to maximize contiguous reads from shared memory.
+    """
+    match intrinsic:
+        case IntrinsicType.VMFMA_F32_32x32x16_F16:
+            return IntrinsicType.MFMA_F32_32x32x8_F16
+        case _:
+            return intrinsic
+
 @dataclass
 class AttentionConfig:
     B: int
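As a quick sanity check of the 0xABCD encoding and the QK/PV split introduced above (a sketch that only restates values already defined in this hunk):

    # Decode VMFMA_F32_32x32x16_F16 = 0x1022 nibble by nibble, per the docstring.
    v = IntrinsicType.VMFMA_F32_32x32x16_F16.value
    assert (v >> 12) & 0xF == 0x1  # A: vendor = AMD
    assert (v >> 8) & 0xF == 0x0   # B: introduced in CDNA1
    assert (v >> 4) & 0xF == 0x2   # C: A-matrix element type = 16-bit float
    assert v & 0xF == 0x2          # D: disambiguator among the CDNA1 f16 intrinsics

    # The virtual intrinsic is meant for the QK matmul only; PV falls back to
    # the native MFMA of the same 32x32 shape.
    assert get_pv_intrinsic(IntrinsicType.VMFMA_F32_32x32x16_F16) == IntrinsicType.MFMA_F32_32x32x8_F16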
@@ -71,20 +130,11 @@ def get_lowering_config(self) -> str:
             + "{ "
             + f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
             + f"reduction = [{', '.join(map(str, self.reduction_tiles))}],"
-            + f"promote_operands = [0, 1, 2]"
+            + f"promote_operands = [1, 2]"
             + " }"
             + f">"
         )
 
-    def get_mma_schedule(self) -> str:
-        return (
-            f"#iree_gpu.mma_schedule<"
-            + f"intrinsic = #iree_gpu.mma_layout<{self.intrinsic}>"
-            + f", subgroup_m_count = {self.M_warp}"
-            + f", subgroup_n_count = {self.N_warp}"
-            + f">"
-        )
-
     def get_translation_info(self) -> str:
         llvm_func_attrs = []
         if self.waves_per_eu:
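Note that operand 0 is dropped from promote_operands here; promotion for the individual matmuls now lives in the per-matmul configs added further down. For reference, with the tile sizes used by the spec at the bottom of this file, the brace-delimited body built above renders as follows (a sketch of just this f-string fragment; the enclosing attribute text is outside this hunk):

    wg_tiles = [1, 128, 0, 0, 0]
    reduction_tiles = [0, 0, 0, 0, 32]
    body = (
        "{ "
        + f"workgroup = [{', '.join(map(str, wg_tiles))}], "
        + f"reduction = [{', '.join(map(str, reduction_tiles))}],"
        + f"promote_operands = [1, 2]"
        + " }"
    )
    # "{ workgroup = [1, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 32],promote_operands = [1, 2] }"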
@@ -93,11 +143,10 @@ def get_translation_info(self) -> str:
             llvm_func_attrs += [f'"denormal-fp-math-f32" = "preserve-sign"']
         return (
             f"#iree_codegen.translation_info<"
-            + f"LLVMGPUVectorDistribute"
+            + f"pipeline = LLVMGPUVectorDistribute"
             + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
             + f" subgroup_size = 64"
-            + f" ,{{mma_schedule = {self.get_mma_schedule()}"
-            + f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
+            + f" , {{llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
             + f"}}"
             + f">"
         )
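The mma_schedule entry is gone from the translation info: the intrinsic now travels in the QK/PV lowering configs instead, and the pipeline is spelled with an explicit "pipeline =" key as ToM IREE expects. With M_warp = 4 and N_warp = 1 (the spec below), workgroup_size is [1 * 4 * 64] = [256], so the attribute comes out roughly as (a sketch, shown with only the denormal func attr set, since the waves-per-eu attribute string is defined outside this hunk):

    #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256] subgroup_size = 64 , {llvm_func_attrs = { "denormal-fp-math-f32" = "preserve-sign" }}>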
@@ -110,6 +159,26 @@ def get_compilation_info(self) -> str:
             + f">"
         )
 
+    def get_qk_config_info(self) -> str:
+        return (
+            f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = {get_intrinsic_string(self.intrinsic)}"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}"
+            + f", promote_operands = [1]"
+            + f"}}>"
+        )
+
+    def get_pv_config_info(self) -> str:
+        return (
+            f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = {get_intrinsic_string(get_pv_intrinsic(self.intrinsic))}"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}"
+            + f", promote_operands = [1]"
+            + f"}}>"
+        )
+
 
 def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
     shapes = f"""\
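With the VMFMA spec used below (M_warp = 4, N_warp = 1), these two helpers produce the following attributes (assembled mechanically from get_intrinsic_string and get_pv_intrinsic above):

    # get_qk_config_info():
    #iree_gpu.lowering_config<{mma_kind = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>, subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}>

    # get_pv_config_info():
    #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}>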
@@ -136,11 +205,11 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
 func.func @main(%Q : !Q, %K : !K, %V : !V) -> !O {{
   %scale = arith.constant 1.0 : !dtype
   %empty = tensor.empty() : !O
-  %O = iree_linalg_ext.attention 
+  %O = iree_linalg_ext.attention
        {{ indexing_maps = [#Q, #K, #V, #S, #O]
        ,decomposition_config = {{
-         qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
-         pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
+         qk_attrs = {{attention_qk_matmul, lowering_config = {tuning.get_qk_config_info()}}},
+         pv_attrs = {{attention_pv_matmul, lowering_config = {tuning.get_pv_config_info()}}}
        }}
        {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
 }}
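The decomposition_config previously hard-coded bare promote_operands configs for both matmuls; it now splices in the full per-matmul lowering configs from the previous hunk, which is what lets QK run on the virtual intrinsic while PV stays on a native MFMA.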
@@ -168,7 +237,7 @@ def compile_attention_config(
 
     # TODO: Use different tuning specs for different configs. This is just a
     # general tuning config that worked well for sdxl shapes.
-    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
     # Generate mlir content
     mlir_content = generate_mlir(config, spec)
 
@@ -211,5 +280,5 @@ def compile_attention_config(
 # Dummy test generation
 if __name__ == "__main__":
     config = AttentionConfig(20, 4096, 64, 64, 4096, "f16")
-    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
     print(generate_mlir(config, spec))
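Both TuningSpec call sites switch from the string "MFMA_F32_32x32x8_F16" to IntrinsicType.VMFMA_F32_32x32x16_F16; per get_pv_intrinsic above, only the QK matmul actually uses the virtual intrinsic, while PV falls back to the MFMA_F32_32x32x8_F16 these specs used before.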
