
Commit f737aee

yudongsi, chengjunlu, and whitneywhtsang authored
Use well tuned kernel options for flex attention (#4484)
Geomean speedup is 1.45x on PVC max1100.

![image](https://github.com/user-attachments/assets/69085d53-ba75-43ad-b3e8-dfb87516b47c)

Signed-off-by: Lu,Chengjun <[email protected]>
Co-authored-by: Lu,Chengjun <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 38a1984 commit f737aee

File tree

3 files changed (+51, -1 lines)


benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 27 additions & 1 deletion
@@ -9,9 +9,35 @@
 
 import torch
 import torch.nn.functional as F
+import torch._inductor
+import torch._inductor.lowering
+import torch._inductor.kernel
+import torch._inductor.kernel.flex_attention as flex_attn
+import torch._inductor.virtualized
 
 import triton_kernels_benchmark as benchmark_suit
 
+# Use TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 or uncomment the following line to print the auto-tune results.
+# torch._inductor.config.max_autotune_gemm = True
+
+
+def get_xpu_config(*args, **kwargs):  # pylint: disable=unused-argument
+    # BLOCK_M, BLOCK_N, num_warps, num_stages
+    configs = [
+        (32, 16, 4, 2),
+        (128, 64, 16, 2),
+        (128, 64, 8, 2),
+        (128, 32, 16, 2),
+        (128, 32, 8, 2),
+    ]
+    return configs
+
+
+# There is an auto-tuning requirement to get the best configuration for flex attention.
+# PyTorch flex attention doesn't support user-driven auto-tuning by default.
+# Override the _get_xpu_config method to provide custom configurations for auto-tuning on XPU.
+flex_attn._get_xpu_config = get_xpu_config  # pylint: disable=protected-access
+
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
 
 # Compile the flex_attention function
@@ -112,7 +138,7 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'triton':
-        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+        kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True}
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
         if MODE == 'bwd':
scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions
@@ -37,3 +37,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 # put your patch applies here
 apply_patch https://github.com/pytorch/pytorch/pull/143553.diff
 apply_patch pytorch_fp64.patch
+apply_patch ./patch/Patch_torch_flex_attention_for_autotune_in_benchmark.patch
scripts/patch/Patch_torch_flex_attention_for_autotune_in_benchmark.patch

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+Subject: [PATCH] Patch torch flex attention for autotune in benchmark
+---
+Index: torch/_inductor/kernel/flex_attention.py
+IDEA additional info:
+Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
+<+>UTF-8
+===================================================================
+diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
+--- a/torch/_inductor/kernel/flex_attention.py	(revision 71e4cab58c04534b7608b4b01685180797271407)
++++ b/torch/_inductor/kernel/flex_attention.py	(date 1749737580817)
+@@ -1643,7 +1643,11 @@
+
+     choices: list[Any] = []
+     configs: list[tuple[int, int, int, int]] = []
+-    configs.append(_get_default_config_fwd(query))
++    default_configs = _get_default_config_fwd(query)
++    if isinstance(default_configs, tuple):
++        configs.append(default_configs)
++    else:
++        configs.extend(default_configs)
+     if config.max_autotune:
+         configs += [
+             (128, 64, 4, 3),