Commit c74b88e

[FA] Add option to specify tuning parameters (#3293)
Add an option to specify tuning parameters. Users can override the default parameters by passing a list of values to autotune from; for example, passing `-BLOCK-M 64 32 128` means those values of `BLOCK_M` are used for autotuning. Also split the options into two argument groups so the help string looks something like:

```
usage: flash-attention [-h] -Z Z -H H -N-CTX N_CTX -D-HEAD D_HEAD [-causal] [-backward]
                       [-BLOCK-M BLOCK_M [BLOCK_M ...]] [-BLOCK-N BLOCK_N [BLOCK_N ...]]
                       [-stages STAGES [STAGES ...]] [-warps WARPS [WARPS ...]]

Run Intel XPU Flash-Attention implementation

options:
  -h, --help            show this help message and exit

Model description:
  Options setting different model metaparameters

  -Z Z                  Batch size
  -H H                  Head count
  -N-CTX N_CTX          Sequence length
  -D-HEAD D_HEAD        Embedding dimension
  -causal               Run causal attention
  -backward             Run backward attention

Tuning configuration:
  Options setting different tuning parameters

  -BLOCK-M BLOCK_M [BLOCK_M ...]
                        Sizes of M
  -BLOCK-N BLOCK_N [BLOCK_N ...]
                        Sizes of N
  -stages STAGES [STAGES ...]
                        Numbers of stages
  -warps WARPS [WARPS ...]
                        Numbers of warps
```

---------

Signed-off-by: victor-eds <[email protected]>
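A quick illustration (not part of the commit) of how argparse handles these flags: the option string `-BLOCK-M` becomes the attribute `BLOCK_M` (the leading dash is stripped and internal dashes become underscores), and `action='extend'` with `nargs='+'` collects every supplied value into a single list.

```python
import argparse

# Standalone sketch of the flag handling added in scripts/flash_attention.py:
# argparse derives the dest 'BLOCK_M' from the option string '-BLOCK-M', and
# action='extend' with nargs='+' gathers all supplied values into one list.
parser = argparse.ArgumentParser(prog='flash-attention')
parser.add_argument('-BLOCK-M', action='extend', nargs='+', type=int, help='Sizes of M')

options = parser.parse_args(['-BLOCK-M', '64', '32', '128'])
print(options.BLOCK_M)  # [64, 32, 128]
```

An invocation of the script itself would then look something like `python scripts/flash_attention.py -Z 1 -H 16 -N-CTX 1024 -D-HEAD 64 -BLOCK-M 64 32 128` (the model values here are illustrative, not taken from the commit).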
1 parent 0234248 commit c74b88e

File tree

1 file changed: scripts/flash_attention.py (37 additions, 7 deletions)
```diff
@@ -3,22 +3,49 @@
 import argparse
 
 import torch
+import triton
 
-from triton_kernels_benchmark.flash_attention_benchmark import _attention
+from triton_kernels_benchmark.flash_attention_benchmark import _attention, tune_attn_fwd
 
 
 def get_options():
     """Gather CL options."""
     parser = argparse.ArgumentParser(prog='flash-attention', description='Run Intel XPU Flash-Attention implementation')
-    parser.add_argument('-Z', type=int, required=True, help='Batch size')
-    parser.add_argument('-H', type=int, required=True, help='Head count')
-    parser.add_argument('-N-CTX', type=int, required=True, help='Sequence length')
-    parser.add_argument('-D-HEAD', type=int, required=True, help='Embedding dimension')
-    parser.add_argument('-causal', action='store_true', help='Run causal attention')
-    parser.add_argument('-backward', action='store_true', help='Run backward attention')
+
+    model = parser.add_argument_group(title='Model description',
+                                      description='Options setting different model metaparameters')
+    model.add_argument('-Z', type=int, required=True, help='Batch size')
+    model.add_argument('-H', type=int, required=True, help='Head count')
+    model.add_argument('-N-CTX', type=int, required=True, help='Sequence length')
+    model.add_argument('-D-HEAD', type=int, required=True, help='Embedding dimension')
+    model.add_argument('-causal', action='store_true', help='Run causal attention')
+    model.add_argument('-backward', action='store_true', help='Run backward attention')
+
+    config = parser.add_argument_group(title='Tuning configuration',
+                                       description='Options setting different tuning parameters')
+    config.add_argument('-BLOCK-M', action='extend', nargs='+', type=int, help='Sizes of M')
+    config.add_argument('-BLOCK-N', action='extend', nargs='+', type=int, help='Sizes of N')
+    config.add_argument('-stages', action='extend', nargs='+', type=int, help='Numbers of stages')
+    config.add_argument('-warps', action='extend', nargs='+', type=int, help='Numbers of warps')
     return parser.parse_args()
 
 
+def get_configs(options):
+    """Get autotuning configurations."""
+    bm_values = options.BLOCK_M if options.BLOCK_M else [128, 256]
+    bn_values = options.BLOCK_N if options.BLOCK_N else [32, 64]
+    stages_values = options.stages if options.stages else [3, 4]
+    warps_values = options.warps if options.warps else [8, 16, 32]
+    return [
+        triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True},
+                      num_stages=s, num_warps=w)
+        for BM in bm_values
+        for BN in bn_values
+        for s in stages_values
+        for w in warps_values
+    ]
+
+
 def run(options):
     """Run the XPU backend FlashAttention benchmark implementation."""
     dtype = torch.float16
@@ -27,6 +54,9 @@ def run(options):
     k = torch.randn_like(q, device='xpu', dtype=dtype, requires_grad=True)
     v = torch.randn_like(q, device='xpu', dtype=dtype, requires_grad=True)
     sm_scale = 0.125
+
+    tune_attn_fwd.configs = get_configs(options)
+
     attention = _attention.apply
     triton_o = attention(q, k, v, options.causal, sm_scale)
     if options.backward:
```
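For context on how the new `get_configs` output is consumed: the assignment `tune_attn_fwd.configs = get_configs(options)` suggests `tune_attn_fwd` is a kernel wrapped by `triton.autotune`, whose Autotuner object keeps its search space in a `configs` attribute. The sketch below uses a hypothetical toy kernel (not code from this commit) to illustrate the same override pattern, assuming a Triton version where the Autotuner exposes `configs` that way.

```python
import triton
import triton.language as tl

# Toy stand-in for tune_attn_fwd (hypothetical kernel, for illustration only).
@triton.autotune(
    configs=[triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8)],
    key=['N_CTX'],
)
@triton.jit
def toy_kernel(x_ptr, N_CTX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Body omitted; only the autotuning plumbing matters here.
    pass

# @triton.autotune wraps the kernel in an Autotuner object; replacing its
# configs list before the first launch swaps out the autotuning search space,
# which is the pattern run() relies on with tune_attn_fwd.configs = get_configs(options).
toy_kernel.configs = [
    triton.Config({'BLOCK_M': bm, 'BLOCK_N': bn}, num_stages=s, num_warps=w)
    for bm in (64, 32, 128)
    for bn in (32, 64)
    for s in (3, 4)
    for w in (8, 16)
]
```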
