9 changes: 8 additions & 1 deletion .github/workflows/run_bench.yml
@@ -47,6 +47,11 @@ jobs:
source bench_venv/bin/activate
python attentionbench/attention_bench.py

- name: TK Attention
run: |
source bench_venv/bin/activate
python attentionbench/attention_bench.py --tk

- name: TK GEMM
run: |
source bench_venv/bin/activate
@@ -66,9 +71,11 @@ jobs:
python convbench/conv_bench.py --roofline results/iree_conv_tk.csv --plot results/iree_conv_tk_f16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention.csv --plot results/iree_attention_fp8.png --dtype f8E4M3FNUZ
python convbench/conv_bench.py --roofline results/iree_attention_tk.csv --plot results/iree_attention_tk_fp16.png --dtype f16
python convbench/conv_bench.py --roofline results/iree_attention_tk.csv --plot results/iree_attention_tk_fp8.png --dtype f8E4M3FNUZ
python convbench/conv_bench.py --roofline results/iree_gemm.csv --plot results/iree_gemm.png
python convbench/conv_bench.py --roofline results/iree_gemm_tk.csv --plot results/iree_gemm_tk.png
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_conv.csv,results/iree_conv_tk.csv --plot results/combined.png
python convbench/conv_bench.py --roofline results/iree_gemm.csv,results/iree_gemm_tk.csv,results/iree_attention.csv,results/iree_attention_tk.csv,results/iree_conv.csv,results/iree_conv_tk.csv --plot results/combined.png

- name: Upload benchmark results
uses: actions/upload-artifact@v4
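For reference, the two new CI pieces (the TK attention benchmark step and its roofline plots) can be reproduced locally. A minimal Python sketch, assuming the bench_venv environment is already activated and the commands are run from the repository root:

import subprocess

# Run the Wave (TK) attention benchmarks; results are written to results/iree_attention_tk.csv.
subprocess.run(["python", "attentionbench/attention_bench.py", "--tk"], check=True)

# Plot one roofline per dtype from the TK attention results, mirroring the new workflow lines.
for dtype, plot in [("f16", "results/iree_attention_tk_fp16.png"),
                    ("f8E4M3FNUZ", "results/iree_attention_tk_fp8.png")]:
    subprocess.run(["python", "convbench/conv_bench.py",
                    "--roofline", "results/iree_attention_tk.csv",
                    "--plot", plot, "--dtype", dtype], check=True)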
20 changes: 16 additions & 4 deletions attentionbench/attention_bench.py
@@ -10,12 +10,17 @@
from utils import *
from attention_utils import *
from problems import get_attention_configs
from wave_attention_utils import compile_wave_attention_config


def compile_attention(tag, config, kernel_dir, vmfb_dir):
def compile_attention_iree(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_attention_config(config, kernel_dir, vmfb_dir)
return (tag, config, mlir_file, vmfb_file)

def compile_attention_wave(tag, config, kernel_dir, vmfb_dir):
mlir_file, vmfb_file = compile_wave_attention_config(config, kernel_dir, vmfb_dir)
return (tag, config, mlir_file, vmfb_file)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Config file updater.")
@@ -36,6 +41,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
parser.add_argument("--batch", help="roofline on certain batch", type=int, default=None)
parser.add_argument("--dtype", help="roofline on certain dtype", default=None)
parser.add_argument("--model", help="roofline on certain model", default=None)
parser.add_argument("--tk", help="Run attention kernels using Wave (TK) kernels", action=argparse.BooleanOptionalAction)

args = parser.parse_args()
logging.basicConfig(level=args.log_level)
@@ -63,6 +69,7 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
compile_args = itertools.starmap(
lambda tag, config: (tag, config, kernel_dir, vmfb_dir), configs
)
compile_attention = compile_attention_wave if args.tk else compile_attention_iree
with Pool(num_cpus) as pool:
compilation_results = list(tqdm(pool.starmap(compile_attention, list(compile_args))))

@@ -80,7 +87,8 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):

results = []
index = 0
output_csv = "results/iree_attention.csv"
output_csv = "results/iree_attention_tk.csv" if args.tk else "results/iree_attention.csv"
entrypoint = "isolated_benchmark" if args.tk else "main"
csv_dir = os.path.dirname(output_csv)
if not os.path.exists(csv_dir):
os.makedirs(csv_dir)
@@ -98,13 +106,17 @@ def compile_attention(tag, config, kernel_dir, vmfb_dir):
f"--device={device}",
"--device_allocator=caching",
f"--module={vmfb_filename}",
"--function=main",
f"--function={entrypoint}",
"--benchmark_repetitions=3",
f"--input={query_shape}",
f"--input={key_shape}",
f"--input={value_shape}",
"--benchmark_repetitions=3",
]

if args.tk:
out_shape = config.get_output_shape()
exec_args.append(f"--input={out_shape}")

# iree benchmark kernels
ret_value, cmd_out, cmd_err = run_iree_command(exec_args)
ok = ret_value == 0
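For a --tk run, the benchmark invocation built above differs from the IREE path in two ways: the entrypoint is isolated_benchmark instead of main, and the output buffer is passed as an extra --input. An illustrative sketch of the resulting argument list for a hypothetical B=16, M=1024, N=64, K1=64, K2=1024 f16 config (tool name, device, and vmfb path are assumptions):

# Hypothetical exec_args for a TK f16 config; shape strings follow the
# AttentionConfig getters: key BxK2xK1, value BxNxK2, output BxMxN widened to f32.
exec_args = [
    "iree-benchmark-module",                 # assumed benchmark tool
    "--device=hip",                          # assumed device
    "--device_allocator=caching",
    "--module=vmfbs/attention_example.vmfb", # hypothetical path
    "--function=isolated_benchmark",
    "--benchmark_repetitions=3",
    "--input=16x1024x64xf16",   # query (B x M x K1)
    "--input=16x1024x64xf16",   # key   (B x K2 x K1)
    "--input=16x64x1024xf16",   # value (B x N x K2)
    "--input=16x1024x64xf32",   # output buffer, TK only (B x M x N)
]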
16 changes: 13 additions & 3 deletions attentionbench/attention_utils.py
@@ -63,6 +63,16 @@ def get_pv_intrinsic(intrinsic: IntrinsicType):
case _:
return intrinsic

def get_32_bit_type(input_type: str):
assert isinstance(input_type, str)
match input_type[0]:
case "f":
return "f32"
case "i":
return "i32"
case _:
raise NotImplementedError("Unexpected input type; cannot derive a 32-bit accumulator type.")
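A quick sanity check of the helper (a sketch): the element type is widened to 32 bits while keeping the float/integer family of the input.

assert get_32_bit_type("f16") == "f32"
assert get_32_bit_type("f8E4M3FNUZ") == "f32"
assert get_32_bit_type("i8") == "i32"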

@dataclass
class AttentionConfig:
B: int
@@ -82,10 +92,10 @@ def get_key_shape(self) -> str:
return f"{self.B}x{self.K2}x{self.K1}x{self.dtype}"

def get_value_shape(self) -> str:
return f"{self.B}x{self.K2}x{self.N}x{self.dtype}"
return f"{self.B}x{self.N}x{self.K2}x{self.dtype}"

def get_output_shape(self) -> str:
return f"{self.B}x{self.M}x{self.N}x{self.dtype}"
return f"{self.B}x{self.M}x{self.N}x{get_32_bit_type(self.dtype)}"

def get_byte_count(self) -> int:
dtype_bits_map = {
@@ -198,7 +208,7 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
attn_kernel = f"""
#Q = affine_map<(b, m, n, k1, k2) -> (b, m, k1)>
#K = affine_map<(b, m, n, k1, k2) -> (b, k2, k1)>
#V = affine_map<(b, m, n, k1, k2) -> (b, k2, n)>
#V = affine_map<(b, m, n, k1, k2) -> (b, n, k2)>
#S = affine_map<(b, m, n, k1, k2) -> ()>
#O = affine_map<(b, m, n, k1, k2) -> (b, m, n)>

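Taken together with the #V map change above, V is now consumed in a B x N x K2 layout and the attention result is accumulated in a 32-bit type. A small sketch with an illustrative config (the field set is assumed to be B, M, N, K1, K2, dtype; keyword arguments are used so the exact field order is not assumed):

cfg = AttentionConfig(B=16, M=1024, N=64, K1=64, K2=1024, dtype="f16")
print(cfg.get_key_shape())     # 16x1024x64xf16  (B x K2 x K1)
print(cfg.get_value_shape())   # 16x64x1024xf16  (B x N x K2, transposed vs. the old B x K2 x N)
print(cfg.get_output_shape())  # 16x1024x64xf32  (B x M x N, widened by get_32_bit_type)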
244 changes: 244 additions & 0 deletions attentionbench/wave_attention_utils.py
@@ -0,0 +1,244 @@
from utils import *
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from attention_utils import AttentionConfig
import traceback

try:
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.utils import (
get_mfma_load_elems_per_thread,
get_mfma_store_elems_per_thread,
)
except ImportError:
TURBINE_AVAILABLE = False
else:
TURBINE_AVAILABLE = True

@dataclass
class AttentionShape:
num_query_heads: int
num_kv_heads: int
head_size: int
head_size_kv: int
# -----------------------
# Prefill specific
num_seqs: Optional[int] = None
max_seq_len: Optional[int] = None
total_seq_len: Optional[int] = None
# -----------------------
# Vanilla attention
query_seq_len: Optional[int] = None
kv_seq_len: Optional[int] = None

def get_vanilla_attention_kernel(
shape: AttentionShape, mfma_variant: MMAType, dynamic_dims: bool, input_dtype: "dtype"
):
# Input sizes
B = tkl.sym.B
M = tkl.sym.M
N = tkl.sym.N
K1 = tkl.sym.K1
K2 = tkl.sym.K2
# Workgroup tile sizes
BLOCK_B = tkl.sym.BLOCK_B
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K2 = tkl.sym.BLOCK_K2
# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
# Other hyperparameters
LOAD_ELEMS_PER_THREAD_QK = index_symbol("LOAD_ELEMS_PER_THREAD_QK")
LOAD_ELEMS_PER_THREAD_PV = index_symbol("LOAD_ELEMS_PER_THREAD_PV")
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

# Expose user-constraints
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)]

if mfma_variant[1] == MMAType.F32_16x16x16_F16 or mfma_variant[1] == MMAType.F32_16x16x32_F8:
Mvec = 16
Nvec = 16
if mfma_variant[1] == MMAType.F32_32x32x8_F16 or mfma_variant[1] == MMAType.F32_32x32x16_F8:
Mvec = 32
Nvec = 32

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(4, 1, 1),
mma_type=mfma_variant[1],
vector_shapes={B: 0, M: Mvec, N: Nvec},
)
]

if dynamic_dims:
constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.iterator(2)
mapping = tkw.IndexMapping(
num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j}
)

@tkw.wave(constraints)
def base_attention(
q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, input_dtype],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, input_dtype],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, input_dtype],
c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
init_sum = tkl.Register[B, M, tkl.f32](0.0)
init_max = tkl.Register[B, M, tkl.f32](-1e6)

# This microkernel encodes the fact that if the reduction
# dimension were tiled, then we would need to materialize a loop.
@tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
def repeat(
partial_max: tkl.Register[B, M, tkl.f32],
partial_sum: tkl.Register[B, M, tkl.f32],
acc: tkl.Register[B, N, M, tkl.f32],
):
imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD_QK)
inner_acc = tkw.mma(k_reg, q_reg, imm_reg, mfma_variant[0])
x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
m_j = tkw.max(x_j, partial_max, dim=K2)
e_delta_max = tkw.exp2(partial_max - m_j)
e_delta = tkw.exp2(x_j - m_j)
e_init = partial_sum * e_delta_max
d_j = tkw.sum(e_delta, e_init, dim=K2)
imm_f16 = tkw.cast(e_delta, input_dtype)
v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD_PV)
new_acc = acc * e_delta_max
acc = tkw.mma(v_reg, imm_f16, new_acc)
return m_j, d_j, acc

# repeat represents the results of the loop
res_max, res_sum, res_mm = repeat
reciprocal_sum = tkw.reciprocal(res_sum)
res = res_mm * reciprocal_sum
tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD)

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
LOAD_ELEMS_PER_THREAD_QK: get_mfma_load_elems_per_thread(mfma_variant[0]),
LOAD_ELEMS_PER_THREAD_PV: get_mfma_load_elems_per_thread(mfma_variant[1]),
STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant[1]),
BLOCK_B: 1,
BLOCK_M: 128,
BLOCK_N: 64,
BLOCK_K2: 64,
B: shape.num_query_heads,
M: shape.query_seq_len,
N: shape.head_size_kv,
K1: shape.head_size,
K2: shape.kv_seq_len,
}

dynamic_symbols = []
dynamic_symbols_map = {}
if dynamic_dims:
dynamic_symbols_map[M] = hyperparams[M]
dynamic_symbols_map[N] = hyperparams[N]
dynamic_symbols_map[B] = hyperparams[B]
dynamic_symbols_map[K2] = hyperparams[K2]
dynamic_symbols.append(M)
dynamic_symbols.append(N)
dynamic_symbols.append(B)
dynamic_symbols.append(K2)
del hyperparams[M]
del hyperparams[N]
del hyperparams[B]
del hyperparams[K2]

return base_attention, hyperparams, dynamic_symbols, dynamic_symbols_map
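The repeat body above is the standard online-softmax (flash-attention style) update, written with base-2 exponentials. Per K2 tile j, with x_j the QK^T scores, it maintains a running max m, denominator d, and accumulator o (a sketch of the recurrence the code implements):

$$
\begin{aligned}
m_j &= \max\bigl(m_{j-1},\ \max_{k \in \text{tile}} x_{jk}\bigr), \qquad
\alpha_j = 2^{\,m_{j-1} - m_j},\\
d_j &= \alpha_j\, d_{j-1} + \sum_{k \in \text{tile}} 2^{\,x_{jk} - m_j}, \qquad
o_j = \alpha_j\, o_{j-1} + \sum_{k \in \text{tile}} 2^{\,x_{jk} - m_j}\, v_k,
\end{aligned}
$$

with the final result written as $o / d$ via the reciprocal-and-multiply after the loop.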


def compile_wave_attention_config(
config: AttentionConfig, kernel_dir: Path, vmfb_dir: Path
) -> tuple[Path, Optional[Path]]:
if not TURBINE_AVAILABLE:
raise ValueError("iree.turbine package is not available")

mlir_file = kernel_dir / (config.get_name() + ".mlir")
vmfb_file = vmfb_dir / (config.get_name() + ".vmfb")

try:
_compile_attention(config, mlir_file, vmfb_file)
except Exception as e:
error_file = vmfb_dir / (config.get_name() + "_error.txt")
print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
with open(error_file, "w") as f:
f.write(str(e))
f.write(traceback.format_exc())
return mlir_file, None

return mlir_file, vmfb_file


def _convert_dtype(dtype: str):
dtypes = {
"i8": tkl.i8,
"i16": tkl.i16,
"i32": tkl.i32,
"i64": tkl.i64,
"f8E4M3FNUZ": tkl.f8e4m3fnuz,
"f16": tkl.f16,
"f32": tkl.f32,
"f64": tkl.f64,
"bf16": tkl.bf16,
}
return dtypes[dtype]


def _compile_attention(config: AttentionConfig, mlir_file: Path, vmfb_file: Path):
shape = AttentionShape(
num_query_heads=config.B,
num_kv_heads=config.B,
query_seq_len=config.M,
head_size_kv=config.N,
head_size=config.K1,
kv_seq_len=config.K2,
)

input_dtype = _convert_dtype(config.dtype)
if input_dtype == tkl.f16:
mfma_variant = (MMAType.F32_32x32x8_F16, MMAType.F32_32x32x8_F16)
elif input_dtype == tkl.f8e4m3fnuz:
mfma_variant = (MMAType.F32_32x32x16_F8, MMAType.F32_32x32x16_F8)
else:
raise NotImplementedError(f"Got {config.dtype}, TK attention currently only support f8E4M3FNUZ and f16.")

base_attention, hyperparams, _, _ = get_vanilla_attention_kernel(
shape, mfma_variant, False, input_dtype
)

# config = get_default_run_config()
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
create_vmfb_file=vmfb_file,
run_config=config,
schedule=False,
inline=False,
):
mod = base_attention().module_op # This will generate vmfb file
with open(mlir_file, "w") as f:
f.write(str(mod))

print(f"Successfully compiled to {vmfb_file}")
2 changes: 1 addition & 1 deletion common_tools/utils/bench_utils.py
@@ -154,7 +154,7 @@ def roofline(results=None, out=None, batch=None, dtype=None, model=None, **kwargs):
if model:
data = filter_model(data, model)
if len(data) == 0:
raise ValueError("No data to plot. If you set filters, there were no kernels with the target config")
raise ValueError(f"No data to plot from file {result_file} with filter: {batch}, {dtype}, {model}. If you set filters, there were no kernels with the target config")
x = [item['arithmetic_intensity'] for item in data]
y = [item['tflops'] for item in data]
