Commit 4312a01

patch 'flex_attn.V.choices.get_flex_attention_fwd_configs'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 92c5802

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 7 additions & 9 deletions
@@ -9,22 +9,20 @@
 
 import torch
 import torch.nn.functional as F
-import torch._inductor.template_heuristics as inductor_heuristics
+import torch._inductor
+import torch._inductor.lowering
+import torch._inductor.kernel
+import torch._inductor.kernel.flex_attention as flex_attn
 from torch._inductor.template_heuristics import FlexConfig
 
 import triton_kernels_benchmark as benchmark_suit
 
 # Use TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 or uncomment the following line to print the auto-tune results.
 # torch._inductor.config.max_autotune_gemm = True
 
-old_get_flex_attn_fwd_configs = inductor_heuristics.XPUConfigHeuristic.get_flex_attn_fwd_configs
-
 
 def get_flex_attn_fwd_configs(*args, **kwargs): # pylint: disable=unused-argument
-    configs = old_get_flex_attn_fwd_configs(*args, **kwargs)
-    # Add our own configurations for FlexAttention forward pass.
-    # BLOCK_M, BLOCK_N, num_stages, num_warps
-    configs += [
+    configs = [
         FlexConfig(32, 16, 2, 4),
         FlexConfig(128, 64, 2, 16),
         FlexConfig(128, 64, 2, 8),
@@ -36,8 +34,8 @@ def get_flex_attn_fwd_configs(*args, **kwargs): # pylint: disable=unused-argument
 
 # There is a auto-tuning requirement to get the best configuration for the flex attention.
 # The pytorch flex attention doesn't support auto-tuning by user by default.
-# Overriding the _get_xpu_config method to provide custom configurations for auto-tuning on XPU.
-inductor_heuristics.XPUConfigHeuristic.get_flex_attn_fwd_configs = get_flex_attn_fwd_configs # pylint: disable=protected-access
+# Overriding the get_flex_attention_fwd_configs method to provide custom configurations for auto-tuning on XPU.
+flex_attn.V.choices.get_flex_attention_fwd_configs = get_flex_attn_fwd_configs
 
 torch._dynamo.config.recompile_limit = 100 # pylint: disable=protected-access
 
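
Note (not part of the commit): a minimal sketch of the same override pattern, shown in isolation. It assumes a PyTorch build where torch._inductor.kernel.flex_attention exposes the virtualized V handle whose choices object provides get_flex_attention_fwd_configs (the hook this patch targets), that FlexConfig takes (BLOCK_M, BLOCK_N, num_stages, num_warps) as in the benchmark, and that an XPU device is available (swap the device string otherwise). The helper name custom_fwd_configs and the tensor shapes are illustrative.

import torch
import torch._inductor.kernel.flex_attention as flex_attn
from torch._inductor.template_heuristics import FlexConfig
from torch.nn.attention.flex_attention import flex_attention


def custom_fwd_configs(*args, **kwargs):  # hypothetical hook; mirrors the patched signature
    # Each FlexConfig is (BLOCK_M, BLOCK_N, num_stages, num_warps).
    return [FlexConfig(64, 32, 2, 4), FlexConfig(128, 64, 2, 8)]


# Swap the config provider so Inductor autotunes only over the list above
# when lowering the FlexAttention forward kernel.
flex_attn.V.choices.get_flex_attention_fwd_configs = custom_fwd_configs

compiled_flex_attention = torch.compile(flex_attention, mode="max-autotune")
q, k, v = (torch.randn(1, 4, 1024, 64, device="xpu", dtype=torch.float16) for _ in range(3))
out = compiled_flex_attention(q, k, v)

The override should be installed before the first compiled call so the custom configurations are visible when Inductor lowers the kernel.
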
0 commit comments
