
Commit 68141c9

Add flex decoding autotune config (#4726)
Signed-off-by: Lu,Chengjun <[email protected]>

Parent: 288599e


benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 16 additions & 1 deletion
@@ -13,7 +13,7 @@
 import torch._inductor.lowering
 import torch._inductor.kernel
 import torch._inductor.kernel.flex_attention as flex_attn
-from torch._inductor.template_heuristics import FlexConfig
+from torch._inductor.template_heuristics import FlexConfig, FlexDecodeConfig
 
 import triton_kernels_benchmark as benchmark_suit
 
@@ -32,10 +32,25 @@ def get_flex_attn_fwd_configs(*args, **kwargs): # pylint: disable=unused-argume
     return configs
 
 
+def get_flex_decode_configs(*args, **kwargs):  # pylint: disable=unused-argument
+    configs = [
+        FlexDecodeConfig(32, 1, 2),
+        FlexDecodeConfig(32, 1, 1),
+        FlexDecodeConfig(32, 2, 2),
+        FlexDecodeConfig(32, 2, 1),
+        FlexDecodeConfig(64, 1, 2),
+        FlexDecodeConfig(64, 1, 1),
+        FlexDecodeConfig(64, 2, 2),
+        FlexDecodeConfig(64, 2, 1),
+    ]
+    return configs
+
+
 # There is a auto-tuning requirement to get the best configuration for the flex attention.
 # The pytorch flex attention doesn't support auto-tuning by user by default.
 # Overriding the get_flex_attention_fwd_configs method to provide custom configurations for auto-tuning on XPU.
 flex_attn.V.choices.get_flex_attention_fwd_configs = get_flex_attn_fwd_configs
+flex_attn.V.choices.get_flex_decode_configs = get_flex_decode_configs
 
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
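For context, a minimal sketch (not part of the commit) of how the new hook is exercised. It assumes an XPU device, a PyTorch build that ships torch.nn.attention.flex_attention, and that inductor's max-autotune path is enabled so the FlexDecodeConfig candidates above are actually benchmarked. Compiling a decode-shaped call (a single query token against a long KV cache) routes inductor to the flex-decode template, where the overridden get_flex_decode_configs supplies the search space.

import torch
import torch._inductor.config as inductor_config
from torch.nn.attention.flex_attention import flex_attention

# Let inductor benchmark the candidate configs instead of picking a default.
inductor_config.max_autotune = True

compiled_flex_attention = torch.compile(flex_attention)

# Decode shape: one query token attending over a long KV cache.
B, H, KV_LEN, D = 4, 16, 4096, 64
q = torch.randn(B, H, 1, D, device="xpu", dtype=torch.float16)
k = torch.randn(B, H, KV_LEN, D, device="xpu", dtype=torch.float16)
v = torch.randn(B, H, KV_LEN, D, device="xpu", dtype=torch.float16)

# The first call compiles; inductor autotunes the decode kernel over the
# FlexDecodeConfig list installed by this commit's override.
out = compiled_flex_attention(q, k, v)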
