 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
+from enum import Enum


+class IntrinsicType(Enum):
+    """
+    Formatting for different target intrinsics:
+        <kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
+
+    Values: 0xABCD where:
+    * A = vendor:
+      * 1 = AMD
+      * 2 = NVIDIA
+    * B = architecture. When an intrinsic exists in multiple architectures, this
+      should be the architecture it was introduced in, as long as it still
+      has the same semantics. If a new architecture breaks an existing
+      intrinsic's semantics, we can use that field for versioning.
+      * For AMD:
+        * 0 = CDNA1
+        * 1 = CDNA2
+        * 2 = CDNA3
+        * 8 = RDNA3
+    * C = element type of A-matrix:
+      * 0 = 64-bit float (e.g. IEEE754 double precision)
+      * 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
+      * 2 = 16-bit float (incl. IEEE754 half and bf16)
+      * 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
+      * C = 8-bit integer (any signedness)
+    * D enumerates intrinsics that share the same 0xABC* bits.
+    """
+    # Intrinsics introduced in CDNA1
+    MFMA_F32_16x16x16_F16 = 0x1020
+    MFMA_F32_32x32x8_F16 = 0x1021
+    VMFMA_F32_32x32x16_F16 = 0x1022
+    MFMA_I32_16x16x16_I8 = 0x10C0
+    MFMA_I32_32x32x8_I8 = 0x10C1
+
+    # Intrinsics introduced in CDNA3
+    MFMA_F32_16x16x32_F8 = 0x1230
+    MFMA_F32_32x32x16_F8 = 0x1231
+    MFMA_I32_16x16x32_I8 = 0x12C0
+    MFMA_I32_32x32x16_I8 = 0x12C1
+
+
+def get_intrinsic_string(intrinsic: IntrinsicType):
+    match intrinsic:
+        case IntrinsicType.VMFMA_F32_32x32x16_F16:
+            return f"#iree_gpu.virtual_mma_layout<intrinsic = {intrinsic.name}>"
+        case _:
+            return f"#iree_gpu.mma_layout<{intrinsic.name}>"
+
+def get_pv_intrinsic(intrinsic: IntrinsicType):
+    """
+    QK intrinsics and PV intrinsics can differ. This is mostly used to
+    select VMFMA for QK to maximize contiguous reads from shared memory.
+    """
+    match intrinsic:
+        case IntrinsicType.VMFMA_F32_32x32x16_F16:
+            return IntrinsicType.MFMA_F32_32x32x8_F16
+        case _:
+            return intrinsic
+
 @dataclass
 class AttentionConfig:
     B: int
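A minimal sketch, for orientation only, of how the 0xABCD encoding documented in the IntrinsicType docstring unpacks, and of the QK-to-PV fallback implemented by get_pv_intrinsic. The helper name describe_intrinsic and its field labels are illustrative and not part of this patch:

    def describe_intrinsic(intrinsic: IntrinsicType) -> dict:
        # Unpack the nibbles named A, B, C, D in the IntrinsicType docstring.
        value = intrinsic.value
        return {
            "vendor": (value >> 12) & 0xF,       # A: 1 = AMD, 2 = NVIDIA
            "arch": (value >> 8) & 0xF,          # B: 0 = CDNA1, 2 = CDNA3, 8 = RDNA3 (for AMD)
            "a_elem_type": (value >> 4) & 0xF,   # C: 2 = 16-bit float, 0xC = 8-bit integer
            "enumerator": value & 0xF,           # D: disambiguates intrinsics sharing 0xABC*
        }

    # VMFMA_F32_32x32x16_F16 = 0x1022 -> AMD, CDNA1, 16-bit float A-matrix, enumerator 2.
    # For the PV matmul, the virtual intrinsic falls back to its hardware counterpart:
    assert get_pv_intrinsic(IntrinsicType.VMFMA_F32_32x32x16_F16) == IntrinsicType.MFMA_F32_32x32x8_F16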
@@ -71,20 +130,11 @@ def get_lowering_config(self) -> str:
             + "{ "
             + f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
             + f"reduction = [{', '.join(map(str, self.reduction_tiles))}],"
-            + f"promote_operands = [0, 1, 2]"
+            + f"promote_operands = [1, 2]"
             + " }"
             + f">"
         )

-    def get_mma_schedule(self) -> str:
-        return (
-            f"#iree_gpu.mma_schedule<"
-            + f"intrinsic = #iree_gpu.mma_layout<{self.intrinsic}>"
-            + f", subgroup_m_count = {self.M_warp}"
-            + f", subgroup_n_count = {self.N_warp}"
-            + f">"
-        )
-
     def get_translation_info(self) -> str:
         llvm_func_attrs = []
         if self.waves_per_eu:
@@ -93,11 +143,10 @@ def get_translation_info(self) -> str:
             llvm_func_attrs += [f'"denormal-fp-math-f32" = "preserve-sign"']
         return (
             f"#iree_codegen.translation_info<"
-            + f"LLVMGPUVectorDistribute"
+            + f"pipeline = LLVMGPUVectorDistribute"
            + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
             + f" subgroup_size = 64"
-            + f" ,{{mma_schedule = {self.get_mma_schedule()}"
-            + f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
+            + f" , {{llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
             + f"}}"
             + f">"
         )
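With mma_schedule removed, translation_info now only names the pipeline, workgroup size, and subgroup size; the intrinsic choice moves into the per-matmul lowering configs added below. A rough, hand-derived illustration, assuming M_warp = 4 and N_warp = 1 as in the spec used later (the llvm_func_attrs contents come from code outside this hunk and are elided):

    workgroup_size = N_warp * M_warp * 64 = 1 * 4 * 64 = 256

    #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256] subgroup_size = 64 , {llvm_func_attrs = { ... }}>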
@@ -110,6 +159,26 @@ def get_compilation_info(self) -> str:
             + f">"
         )

+    def get_qk_config_info(self) -> str:
+        return (
+            f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = {get_intrinsic_string(self.intrinsic)}"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}"
+            + f", promote_operands = [1]"
+            + f"}}>"
+        )
+
+    def get_pv_config_info(self) -> str:
+        return (
+            f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = {get_intrinsic_string(get_pv_intrinsic(self.intrinsic))}"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}"
+            + f", promote_operands = [1]"
+            + f"}}>"
+        )
+

 def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
     shapes = f"""\
@@ -136,11 +205,11 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
 func.func @main(%Q : !Q, %K : !K, %V : !V) -> !O {{
   %scale = arith.constant 1.0 : !dtype
   %empty = tensor.empty() : !O
-  %O = iree_linalg_ext.attention 
+  %O = iree_linalg_ext.attention
        {{ indexing_maps = [#Q, #K, #V, #S, #O]
        ,decomposition_config = {{
-         qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}> }},
-         pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}> }}
+         qk_attrs = {{attention_qk_matmul, lowering_config = {tuning.get_qk_config_info()} }},
+         pv_attrs = {{attention_pv_matmul, lowering_config = {tuning.get_pv_config_info()} }}
        }}
        {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
        }}
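For reference, roughly what the two tuning-derived configs expand to in the emitted MLIR for the VMFMA spec used below. These strings are hand-derived from get_qk_config_info / get_pv_config_info (assuming the 4 and 1 in the spec are M_warp and N_warp), not captured from an actual run:

    qk_attrs = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>, subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}> },
    pv_attrs = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}> }

QK keeps the virtual intrinsic for contiguous shared-memory reads; PV falls back to the hardware MFMA via get_pv_intrinsic().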
@@ -168,7 +237,7 @@ def compile_attention_config(

     # TODO: Use different tuning specs for different configs. This is just a
     # general tuning config that worked well for sdxl shapes.
-    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
     # Generate mlir content
     mlir_content = generate_mlir(config, spec)

@@ -211,5 +280,5 @@ def compile_attention_config(
 # Dummy test generation
 if __name__ == "__main__":
     config = AttentionConfig(20, 4096, 64, 64, 4096, "f16")
-    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
     print(generate_mlir(config, spec))