Commit 0a9526a

Add TK Wave kernels to attention benchmark
Signed-off-by: Stanley Winata <[email protected]>
Parent: 87c0c8c

File tree (4 files changed: +280, -7 lines)

.github/workflows/run_bench.yml
attentionbench/attention_bench.py
attentionbench/attention_utils.py
attentionbench/wave_attention_utils.py

.github/workflows/run_bench.yml

Lines changed: 7 additions & 0 deletions
@@ -47,6 +47,11 @@ jobs:
           source bench_venv/bin/activate
           python attentionbench/attention_bench.py

+      - name: TK Attention
+        run: |
+          source bench_venv/bin/activate
+          python attentionbench/attention_bench.py --tk
+
       - name: TK GEMM
         run: |
           source bench_venv/bin/activate
@@ -66,6 +71,8 @@ jobs:
           python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_f16.png --dtype f16
           python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
           python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
+          python convbench/conv_bench.py --roofline results/iree_attention_tk.csv --plot results/iree_attention_tk_fp16.png --dtype f16
+          python convbench/conv_bench.py --roofline results/iree_attention_tk.csv --plot results/iree_attention_tk_fp8.png --dtype f8E4M3FNUZ
           python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
           python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
           python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv,results/iree_conv_tk.csv --plot results/combined.png

attentionbench/attention_bench.py

Lines changed: 16 additions & 4 deletions
@@ -10,12 +10,17 @@
 from utils import *
 from attention_utils import *
 from problems import get_attention_configs
+from wave_attention_utils import compile_wave_attention_config


-def compile_attention(tag, config, kernel_dir, vmfb_dir):
+def compile_attention_iree(tag, config, kernel_dir, vmfb_dir):
     mlir_file, vmfb_file = compile_attention_config(config, kernel_dir, vmfb_dir)
     return (tag, config, mlir_file, vmfb_file)

+
+def compile_attention_wave(tag, config, kernel_dir, vmfb_dir):
+    mlir_file, vmfb_file = compile_wave_attention_config(config, kernel_dir, vmfb_dir)
+    return (tag, config, mlir_file, vmfb_file)
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Config file updater.")
@@ -36,6 +41,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
     parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
     parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
     parser.add_argument("--model", help="roofline on certain model", default=None)
+    parser.add_argument("--tk", help="Run attention kernels using Wave kernels", action=argparse.BooleanOptionalAction)

     args = parser.parse_args()
     logging.basicConfig(level=args.log_level)
@@ -63,6 +69,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
     compile_args = itertools.starmap(
         lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
     )
+    compile_attention = compile_attention_wave if args.tk else compile_attention_iree
     with Pool(num_cpus) as pool:
         compilation_results = list(tqdm(pool.starmap(compile_attention, list(compile_args))))

@@ -80,7 +87,8 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):

     results = []
     index = 0
-    output_csv = "results/iree_attention.csv"
+    output_csv = "results/iree_attention_tk.csv" if args.tk else "results/iree_attention.csv"
+    entrypoint = "isolated_benchmark" if args.tk else "main"
     csv_dir = os.path.dirname(output_csv)
     if not os.path.exists(csv_dir):
         os.makedirs(csv_dir)
@@ -98,13 +106,17 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
             f"--device={device}",
             "--device_allocator=caching",
             f"--module={vmfb_filename}",
-            "--function=main",
+            f"--function={entrypoint}",
+            "--benchmark_repetitions=3",
             f"--input={query_shape}",
             f"--input={key_shape}",
             f"--input={value_shape}",
-            "--benchmark_repetitions=3",
         ]

+        if args.tk:
+            out_shape = config.get_output_shape()
+            exec_args.append(f"--input={out_shape}")
+
         # iree benchmark kernels
         ret_value, cmd_out, cmd_err = run_iree_command(exec_args)
         ok = ret_value == 0
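
Note on the --tk path above: the Wave-compiled module is benchmarked through an isolated_benchmark entrypoint that takes the output buffer as an explicit argument, so one extra --input with the 32-bit output shape is appended. A minimal sketch of the flags the script would hand to run_iree_command for one hypothetical f16 config; the device string, module path, and concrete shape values are illustrative assumptions, not part of the diff:

# Illustrative only: hypothetical values for one config with
# B=2, M=1024, N=64, K1=64, K2=1024, dtype=f16.
device = "hip://0"                                   # assumed device URI
vmfb_filename = "vmfb_dump/attention_example.vmfb"   # assumed path
query_shape = "2x1024x64xf16"    # B x M x K1
key_shape = "2x1024x64xf16"      # B x K2 x K1
value_shape = "2x64x1024xf16"    # B x N x K2 (transposed V layout)

exec_args = [
    f"--device={device}",
    "--device_allocator=caching",
    f"--module={vmfb_filename}",
    "--function=isolated_benchmark",   # "main" when --tk is not set
    "--benchmark_repetitions=3",
    f"--input={query_shape}",
    f"--input={key_shape}",
    f"--input={value_shape}",
    "--input=2x1024x64xf32",           # B x M x N output buffer, appended only with --tk
]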

attentionbench/attention_utils.py

Lines changed: 13 additions & 3 deletions
@@ -63,6 +63,16 @@ def get_pv_intrinsic(intrinsic: IntrinsicType):
         case _:
             return intrinsic

+def get_32_bit_type(input_type: str):
+    assert isinstance(input_type, str)
+    match input_type[0]:
+        case "f":
+            return "f32"
+        case "i":
+            return "i32"
+        case _:
+            raise NotImplementedError("Unexpected type when deriving the 32-bit type in attention utils.")
+
 @dataclass
 class AttentionConfig:
     B: int
@@ -82,10 +92,10 @@ def get_key_shape(self) -> str:
         return f"{self.B}x{self.K2}x{self.K1}x{self.dtype}"

     def get_value_shape(self) -> str:
-        return f"{self.B}x{self.K2}x{self.N}x{self.dtype}"
+        return f"{self.B}x{self.N}x{self.K2}x{self.dtype}"

     def get_output_shape(self) -> str:
-        return f"{self.B}x{self.M}x{self.N}x{self.dtype}"
+        return f"{self.B}x{self.M}x{self.N}x{get_32_bit_type(self.dtype)}"

     def get_byte_count(self) -> int:
         dtype_bits_map = {
@@ -198,7 +208,7 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
     attn_kernel = f"""
 #Q = affine_map<(b, m, n, k1, k2) -> (b, m, k1)>
 #K = affine_map<(b, m, n, k1, k2) -> (b, k2, k1)>
-#V = affine_map<(b, m, n, k1, k2) -> (b, k2, n)>
+#V = affine_map<(b, m, n, k1, k2) -> (b, n, k2)>
 #S = affine_map<(b, m, n, k1, k2) -> ()>
 #O = affine_map<(b, m, n, k1, k2) -> (b, m, n)>
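
The two changes above go together: the value tensor layout is now B x N x K2 (matching the Wave kernel's v argument and the updated #V affine map), and the output shape reports a 32-bit element type because the attention result is accumulated in f32 (i32 for integer inputs). A small sketch of what the helpers return, using made-up config values (B=2, M=1024, N=64, K1=64, K2=1024, dtype "f16"), assuming AttentionConfig exposes exactly the fields referenced in its methods:

# Hypothetical values, for illustration only.
config = AttentionConfig(B=2, M=1024, N=64, K1=64, K2=1024, dtype="f16")

get_32_bit_type("f16")      # -> "f32"
get_32_bit_type("i8")       # -> "i32"
config.get_value_shape()    # -> "2x64x1024xf16"  (B x N x K2)
config.get_output_shape()   # -> "2x1024x64xf32"  (B x M x N, f32 accumulator)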

attentionbench/wave_attention_utils.py (new file)

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
+from utils import *
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+from attention_utils import AttentionConfig
+import traceback
+
+try:
+    import iree.turbine.kernel as tk
+    import iree.turbine.kernel.lang as tkl
+    import iree.turbine.kernel.wave as tkw
+    from iree.turbine.kernel.lang.global_symbols import *
+    from iree.turbine.kernel.wave.constraints import MMAType
+    from iree.turbine.kernel.wave.utils import (
+        get_mfma_load_elems_per_thread,
+        get_mfma_store_elems_per_thread,
+    )
+except ImportError:
+    TURBINE_AVAILABLE = False
+else:
+    TURBINE_AVAILABLE = True
+
+@dataclass
+class AttentionShape:
+    num_query_heads: int
+    num_kv_heads: int
+    head_size: int
+    head_size_kv: int
+    # -----------------------
+    # Prefill specific
+    num_seqs: Optional[int] = None
+    max_seq_len: Optional[int] = None
+    total_seq_len: Optional[int] = None
+    # -----------------------
+    # Vanilla attention
+    query_seq_len: Optional[int] = None
+    kv_seq_len: Optional[int] = None
+
+def get_vanilla_attention_kernel(
+    shape: AttentionShape, mfma_variant: MMAType, dynamic_dims: bool, input_dtype: "dtype"
+):
+    # Input sizes
+    B = tkl.sym.B
+    M = tkl.sym.M
+    N = tkl.sym.N
+    K1 = tkl.sym.K1
+    K2 = tkl.sym.K2
+    # Workgroup tile sizes
+    BLOCK_B = tkl.sym.BLOCK_B
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K2 = tkl.sym.BLOCK_K2
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+    # Other hyperparameters
+    LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK")
+    LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV")
+    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)]
+
+    if mfma_variant[1] == MMAType.F32_16x16x16_F16 or mfma_variant[1] == MMAType.F32_16x16x32_F8:
+        Mvec = 16
+        Nvec = 16
+    if mfma_variant[1] == MMAType.F32_32x32x8_F16 or mfma_variant[1] == MMAType.F32_32x32x16_F8:
+        Mvec = 32
+        Nvec = 32
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(4, 1, 1),
+            mma_type=mfma_variant[1],
+            vector_shapes={B: 0, M: Mvec, N: Nvec},
+        )
+    ]
+
+    if dynamic_dims:
+        constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.iterator(2)
+    mapping = tkw.IndexMapping(
+        num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j}
+    )
+
+    @tkw.wave(constraints)
+    def base_attention(
+        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, input_dtype],
+        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, input_dtype],
+        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, input_dtype],
+        c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
+        init_sum = tkl.Register[B, M, tkl.f32](0.0)
+        init_max = tkl.Register[B, M, tkl.f32](-1e6)
+
+        # This microkernel encodes the fact that if the reduction
+        # dimension were tiled, then we would need to materialize a loop.
+        @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
+        def repeat(
+            partial_max: tkl.Register[B, M, tkl.f32],
+            partial_sum: tkl.Register[B, M, tkl.f32],
+            acc: tkl.Register[B, N, M, tkl.f32],
+        ):
+            imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
+            inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
+            x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
+            m_j = tkw.max(x_j, partial_max, dim=K2)
+            e_delta_max = tkw.exp2(partial_max - m_j)
+            e_delta = tkw.exp2(x_j - m_j)
+            e_init = partial_sum * e_delta_max
+            d_j = tkw.sum(e_delta, e_init, dim=K2)
+            imm_f16 = tkw.cast(e_delta, input_dtype)
+            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV)
+            new_acc = acc * e_delta_max
+            acc = tkw.mma(v_reg, imm_f16, new_acc)
+            return m_j, d_j, acc
+
+        # repeat represents the results of the loop
+        res_max, res_sum, res_mm = repeat
+        reciprocal_sum = tkw.reciprocal(res_sum)
+        res = res_mm * reciprocal_sum
+        tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    hyperparams = {
+        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+        LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]),
+        LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]),
+        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]),
+        BLOCK_B: 1,
+        BLOCK_M: 128,
+        BLOCK_N: 64,
+        BLOCK_K2: 64,
+        B: shape.num_query_heads,
+        M: shape.query_seq_len,
+        N: shape.head_size_kv,
+        K1: shape.head_size,
+        K2: shape.kv_seq_len,
+    }
+
+    dynamic_symbols = []
+    dynamic_symbols_map = {}
+    if dynamic_dims:
+        dynamic_symbols_map[M] = hyperparams[M]
+        dynamic_symbols_map[N] = hyperparams[N]
+        dynamic_symbols_map[B] = hyperparams[B]
+        dynamic_symbols_map[K2] = hyperparams[K2]
+        dynamic_symbols.append(M)
+        dynamic_symbols.append(N)
+        dynamic_symbols.append(B)
+        dynamic_symbols.append(K2)
+        del hyperparams[M]
+        del hyperparams[N]
+        del hyperparams[B]
+        del hyperparams[K2]
+
+    return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map
+
+
+def compile_wave_attention_config(
+    config: AttentionConfig, kernel_dir: Path, vmfb_dir: Path
+) -> tuple[Path, Optional[Path]]:
+    if not TURBINE_AVAILABLE:
+        raise ValueError("iree.turbine package is not available")
+
+    mlir_file = kernel_dir / (config.get_name() + ".mlir")
+    vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")
+
+    try:
+        _compile_attention(config, mlir_file, vmfb_file)
+    except Exception as e:
+        error_file = vmfb_dir / (config.get_name() + "_error.txt")
+        print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
+        with open(error_file, "w") as f:
+            f.write(str(e))
+            f.write(traceback.format_exc())
+        return mlir_file, None
+
+    return mlir_file, vmfb_file
+
+
+def _convert_dtype(dtype: str):
+    dtypes = {
+        "i8": tkl.i8,
+        "i16": tkl.i16,
+        "i32": tkl.i32,
+        "i64": tkl.i64,
+        "f8E4M3FNUZ": tkl.f8e4m3fnuz,
+        "f16": tkl.f16,
+        "f32": tkl.f32,
+        "f64": tkl.f64,
+        "bf16": tkl.bf16,
+    }
+    return dtypes[dtype]
+
+
+def _compile_attention(config: AttentionConfig, mlir_file: Path, vmfb_file: Path):
+    shape = AttentionShape(
+        num_query_heads=config.B,
+        num_kv_heads=config.B,
+        query_seq_len=config.M,
+        head_size_kv=config.N,
+        head_size=config.K1,
+        kv_seq_len=config.K2,
+    )
+
+    input_dtype = _convert_dtype(config.dtype)
+    if input_dtype == tkl.f16:
+        mfma_variant = (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16)
+    elif input_dtype == tkl.f8e4m3fnuz:
+        mfma_variant = (MMAType.F32_32x32x16_F8, MMAType.F32_32x32x16_F8)
+    else:
+        raise NotImplementedError(f"Got {config.dtype}; TK attention currently only supports f8E4M3FNUZ and f16.")
+
+    base_attention, hyperparams, _, _ = get_vanilla_attention_kernel(
+        shape, mfma_variant, False, input_dtype
+    )
+
+    # config = get_default_run_config()
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    with tk.gen.TestLaunchContext(
+        hyperparams,
+        canonicalize=True,
+        create_vmfb_file=vmfb_file,
+        run_config=config,
+        schedule=False,
+        inline=False,
+    ):
+        mod = base_attention().module_op  # This will generate the vmfb file
+        with open(mlir_file, "w") as f:
+            f.write(str(mod))
+
+    print(f"Successfully compiled to {vmfb_file}")
