@@ -9,22 +9,20 @@
 
 import torch
 import torch.nn.functional as F
-import torch._inductor.template_heuristics as inductor_heuristics
+import torch._inductor
+import torch._inductor.lowering
+import torch._inductor.kernel
+import torch._inductor.kernel.flex_attention as flex_attn
 from torch._inductor.template_heuristics import FlexConfig
 
 import triton_kernels_benchmark as benchmark_suit
 
 # Use TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 or uncomment the following line to print the auto-tune results.
 # torch._inductor.config.max_autotune_gemm = True
 
-old_get_flex_attn_fwd_configs = inductor_heuristics.XPUConfigHeuristic.get_flex_attn_fwd_configs
-
 
 def get_flex_attn_fwd_configs(*args, **kwargs):  # pylint: disable=unused-argument
-    configs = old_get_flex_attn_fwd_configs(*args, **kwargs)
-    # Add our own configurations for FlexAttention forward pass.
-    # BLOCK_M, BLOCK_N, num_stages, num_warps
-    configs += [
+    configs = [
         FlexConfig(32, 16, 2, 4),
         FlexConfig(128, 64, 2, 16),
         FlexConfig(128, 64, 2, 8),
@@ -36,8 +34,8 @@ def get_flex_attn_fwd_configs(*args, **kwargs):  # pylint: disable=unused-argument
 
 # There is an auto-tuning requirement to get the best configuration for flex attention.
 # The PyTorch flex attention doesn't support user-provided auto-tuning configurations by default.
-# Overriding the _get_xpu_config method to provide custom configurations for auto-tuning on XPU.
-inductor_heuristics.XPUConfigHeuristic.get_flex_attn_fwd_configs = get_flex_attn_fwd_configs  # pylint: disable=protected-access
+# Overriding the get_flex_attention_fwd_configs method to provide custom configurations for auto-tuning on XPU.
+flex_attn.V.choices.get_flex_attention_fwd_configs = get_flex_attn_fwd_configs
 
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
 
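For context, a minimal sketch (not part of this commit) of how the patched heuristic ends up being exercised: with Inductor auto-tuning enabled, compiling flex_attention leads the template chooser to call the overridden get_flex_attention_fwd_configs and benchmark each FlexConfig (BLOCK_M, BLOCK_N, num_stages, num_warps) in the list above. The tensor shapes, dtype, and xpu device below are illustrative assumptions, not values taken from the benchmark.

import torch
from torch._inductor import config as inductor_config
from torch.nn.attention.flex_attention import flex_attention

# Enable Inductor auto-tuning so the custom FlexConfig list is actually benchmarked,
# mirroring the TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 hint in the comment above.
inductor_config.max_autotune_gemm = True

# Illustrative inputs: (batch, heads, seq_len, head_dim) on an XPU device.
q = torch.randn(1, 8, 1024, 64, device="xpu", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# flex_attention must go through torch.compile; Inductor then selects the best
# forward-kernel configuration from the patched config list.
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v)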