Commit 869ed8a
[FlexAttn] Fix performance degradation (#5038)
Without this patch the tensor descriptor implementation is not used. The change in `flex_attention.py` was removed from pytorch/pytorch#143553 before merging, and the requirement in `can_use_tma` was too restrictive for using tensor descriptors.

Fixes #5036
Benchmark CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17456013457
Signed-off-by: Whitney Tsang <[email protected]>
1 parent cfb23d7 commit 869ed8a
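A minimal sketch of the kernel-option selection the patch introduces, assuming an XPU-enabled PyTorch build: `cur_kernel_options` and `can_use_tma` are the names used in the patched Inductor files, while `choose_use_tma` below is only a hypothetical wrapper added to illustrate the branch, not actual Inductor code.

import torch

def choose_use_tma(query, key, value, cur_kernel_options, can_use_tma):
    # Sketch of the selection added to flex_attention / flex_decoding:
    # USE_TMA stays off by default, but is forced on when running on XPU
    # and can_use_tma() reports the tensors qualify for tensor descriptors.
    cur_kernel_options.setdefault("USE_TMA", False)
    if torch.xpu.is_available() and can_use_tma(query, key, value):
        cur_kernel_options["USE_TMA"] = True
    return cur_kernel_options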

File tree

3 files changed (+50 / -22 lines)


scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions
@@ -36,3 +36,4 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 
 # put your patch applies here
 apply_patch ./patch/flex_decoding_tensor_desc.patch
+apply_patch ./patch/use_tma.patch

scripts/patch/flex_decoding_tensor_desc.patch

Lines changed: 0 additions & 22 deletions
@@ -1,25 +1,3 @@
-diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
-index 91ba941da0..a6b87212ad 100644
---- a/torch/_inductor/kernel/flex/flex_decoding.py
-+++ b/torch/_inductor/kernel/flex/flex_decoding.py
-@@ -6,6 +6,7 @@ from typing import Any
- import sympy
-
- import torch
-+from torch._inductor.utils import can_use_tma
- from torch._inductor.virtualized import V
-
- from ... import ir
-@@ -326,6 +327,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
-     # Set default to False
-     cur_kernel_options.setdefault("USE_TMA", False)
-
-+    if torch.xpu.is_available() and can_use_tma(query, key, value):
-+        cur_kernel_options["USE_TMA"] = True
-+
-     # Add ROCm-specific parameters if they exist in the config
-     for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
-         if hasattr(conf, attrib):
 diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
 index 31c64055e3..a75792787a 100644
 --- a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
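Note: the hunk removed here reappears unchanged in scripts/patch/use_tma.patch below, which additionally patches flex_attention.py and torch/_inductor/utils.py.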

scripts/patch/use_tma.patch

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
+index 39c8f737c7..5a8df2c3f9 100644
+--- a/torch/_inductor/kernel/flex/flex_attention.py
++++ b/torch/_inductor/kernel/flex/flex_attention.py
+@@ -311,6 +311,9 @@ def flex_attention(
+         # USE TMA = false by default
+         cur_kernel_options.setdefault("USE_TMA", False)
+
++        if torch.xpu.is_available() and can_use_tma(query, key, value):
++            cur_kernel_options["USE_TMA"] = True
++
+         if cur_kernel_options["USE_TMA"] and can_use_tma(query, key, value):
+             cur_kernel_options["USE_TMA"] = True
+
+diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
+index 91ba941da0..a6b87212ad 100644
+--- a/torch/_inductor/kernel/flex/flex_decoding.py
++++ b/torch/_inductor/kernel/flex/flex_decoding.py
+@@ -6,6 +6,7 @@ from typing import Any
+ import sympy
+
+ import torch
++from torch._inductor.utils import can_use_tma
+ from torch._inductor.virtualized import V
+
+ from ... import ir
+@@ -326,6 +327,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
+     # Set default to False
+     cur_kernel_options.setdefault("USE_TMA", False)
+
++    if torch.xpu.is_available() and can_use_tma(query, key, value):
++        cur_kernel_options["USE_TMA"] = True
++
+     # Add ROCm-specific parameters if they exist in the config
+     for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
+         if hasattr(conf, attrib):
+diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
+index 0876f99307..4fa1c87560 100644
+--- a/torch/_inductor/utils.py
++++ b/torch/_inductor/utils.py
+@@ -1696,7 +1696,7 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
+     strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
+
+     # Every logical size ≥ 2
+-    if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
++    if not torch.xpu.is_available() and any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
+         return False
+
+     # Find the single contiguous (“inner”) dim
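For context, a minimal sketch of what the utils.py hunk relaxes, with a plain list of ints standing in for Inductor's symbolic sizes (the real check runs inside `can_use_tma` via `V.graph.sizevars.statically_known_geq`; `sizes_allow_tma` is a hypothetical helper used only for illustration).

import torch

def sizes_allow_tma(sizes):
    # Upstream rejects TMA when any logical size is < 2; the patch skips
    # that rejection on XPU so the tensor-descriptor path remains usable.
    if not torch.xpu.is_available() and any(s < 2 for s in sizes):
        return False
    return True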
