9 | 9 |
10 | 10 | import torch
11 | 11 | import torch.nn.functional as F
| 12 | +import torch._inductor |
| 13 | +import torch._inductor.lowering |
| 14 | +import torch._inductor.kernel |
| 15 | +import torch._inductor.kernel.flex_attention as flex_attn |
| 16 | +import torch._inductor.virtualized |
12 | 17 |
13 | 18 | import triton_kernels_benchmark as benchmark_suit
14 | 19 |
| 20 | +# Set TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 or uncomment the following line to print the auto-tune results. |
| 21 | +# torch._inductor.config.max_autotune_gemm = True |
| 22 | + |
| 23 | + |
| 24 | +def get_xpu_config(*args, **kwargs): # pylint: disable=unused-argument |
| 25 | + # BLOCK_M, BLOCK_N, num_warps, num_stages |
| 26 | + configs = [ |
| 27 | + (32, 16, 4, 2), |
| 28 | + (128, 64, 16, 2), |
| 29 | + (128, 64, 8, 2), |
| 30 | + (128, 32, 16, 2), |
| 31 | + (128, 32, 8, 2), |
| 32 | + ] |
| 33 | + return configs |
| 34 | + |
| 35 | + |
| 36 | +# Auto-tuning is required to find the best configuration for flex attention. |
| 37 | +# PyTorch flex attention does not support user-supplied auto-tuning configs by default, |
| 38 | +# so override the _get_xpu_config function to provide custom configurations for auto-tuning on XPU. |
| 39 | +flex_attn._get_xpu_config = get_xpu_config # pylint: disable=protected-access |
| 40 | + |
15 | 41 | torch._dynamo.config.recompile_limit = 100 # pylint: disable=protected-access
16 | 42 |
17 | 43 | # Compile the flex_attention function
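Note: the override above only changes which tile configurations Inductor's flex attention lowering will autotune over on XPU; the benchmark still compiles and calls `flex_attention` the usual way. A minimal standalone sketch of that flow, assuming PyTorch 2.5+ with an XPU build (the shapes, the causal mask, and the variable names below are illustrative, not taken from this benchmark):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Compile once and reuse; Inductor consults the patched _get_xpu_config
# when it autotunes the flex attention kernel for the XPU backend.
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

def causal(b, h, q_idx, kv_idx):
    # Causal mask: a query position may only attend to itself and earlier positions.
    return q_idx >= kv_idx

# Illustrative shapes only.
B, H, SEQ, D = 1, 8, 1024, 64
q, k, v = (torch.randn(B, H, SEQ, D, device='xpu', dtype=torch.float16) for _ in range(3))
block_mask = create_block_mask(causal, B, H, SEQ, SEQ, device='xpu')

out = compiled_flex_attention(q, k, v, block_mask=block_mask)
```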
@@ -112,7 +138,7 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
112 | 138 | _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
113 | 139 |
114 | 140 | elif provider == 'triton':
115 | | - kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True} |
| 141 | + kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True} |
116 | 142 | triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
117 | 143 | not H_q == H_kv), kernel_options=kernel_options)
118 | 144 | if MODE == 'bwd':
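With the hardcoded `num_stages`/`num_warps` removed, those values now come from the autotuned configs supplied by `get_xpu_config`, and only `BLOCKS_ARE_CONTIGUOUS` stays pinned. For reference, a sketch of the resulting call site; this is a fragment of the benchmark, reusing its existing `q`, `k`, `v`, `block_mask`, `sm_scale`, `H_q`, and `H_kv`, and `H_q != H_kv` is equivalent to the original `not H_q == H_kv`:

```python
# Only the block-contiguity hint stays fixed; num_stages/num_warps are autotuned.
kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True}
triton_fn = lambda: compiled_flex_attention(
    q, k, v,
    block_mask=block_mask,
    scale=sm_scale,
    enable_gqa=(H_q != H_kv),  # grouped-query attention when Q and KV head counts differ
    kernel_options=kernel_options,
)
```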