 import torch._inductor.lowering
 import torch._inductor.kernel
 import torch._inductor.kernel.flex_attention as flex_attn
-from torch._inductor.template_heuristics import FlexConfig
+from torch._inductor.template_heuristics import FlexConfig, FlexDecodeConfig
 
 import triton_kernels_benchmark as benchmark_suit
 
@@ -32,10 +32,25 @@ def get_flex_attn_fwd_configs(*args, **kwargs):  # pylint: disable=unused-argument
     return configs
 
 
+def get_flex_decode_configs(*args, **kwargs):  # pylint: disable=unused-argument
+    configs = [
+        FlexDecodeConfig(32, 1, 2),
+        FlexDecodeConfig(32, 1, 1),
+        FlexDecodeConfig(32, 2, 2),
+        FlexDecodeConfig(32, 2, 1),
+        FlexDecodeConfig(64, 1, 2),
+        FlexDecodeConfig(64, 1, 1),
+        FlexDecodeConfig(64, 2, 2),
+        FlexDecodeConfig(64, 2, 1),
+    ]
+    return configs
+
+
 # There is an auto-tuning requirement to get the best configuration for flex attention.
 # The PyTorch flex attention doesn't support user-supplied auto-tuning configurations by default.
 # Override the get_flex_attention_fwd_configs method to provide custom configurations for auto-tuning on XPU.
 flex_attn.V.choices.get_flex_attention_fwd_configs = get_flex_attn_fwd_configs
+flex_attn.V.choices.get_flex_decode_configs = get_flex_decode_configs
 
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
 
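For context, a minimal usage sketch (not part of this diff; the tensor shapes, dtype, device, and the max-autotune toggle are illustrative assumptions): once the hooks above are overridden, compiling `flex_attention` routes its lowering through Inductor, which can then auto-tune over the custom `FlexConfig`/`FlexDecodeConfig` candidates.

```python
# Illustrative sketch only: shapes, dtype, and device are assumptions,
# not values taken from the benchmark suite.
import torch
import torch._inductor.config
from torch.nn.attention.flex_attention import flex_attention

# Assumption: max-autotune is enabled so Inductor benchmarks the candidate configs.
torch._inductor.config.max_autotune = True

q = torch.randn(2, 8, 1024, 64, device="xpu", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="xpu", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="xpu", dtype=torch.float16)

# torch.compile lowers flex_attention through Inductor, which consults the
# overridden get_flex_attention_fwd_configs / get_flex_decode_configs hooks.
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v)
```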