Commit 10bddf6

[PATCH] Update flex_decoding_tensor_desc.patch (#5019)
The PyTorch pin update in #5006 removed the block pointer implementation, so the XPU path now sets USE_TMA to True instead of leaving it off.

Signed-off-by: Whitney Tsang <[email protected]>
Parent: e835880

File tree: 1 file changed (+4, -5)


scripts/patch/flex_decoding_tensor_desc.patch

Lines changed: 4 additions & 5 deletions
@@ -1,5 +1,5 @@
 diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
-index 91ba941da06..2eba71224c6 100644
+index 91ba941da0..a6b87212ad 100644
 --- a/torch/_inductor/kernel/flex/flex_decoding.py
 +++ b/torch/_inductor/kernel/flex/flex_decoding.py
 @@ -6,6 +6,7 @@ from typing import Any
@@ -10,19 +10,18 @@ index 91ba941da06..2eba71224c6 100644
  from torch._inductor.virtualized import V
 
  from ... import ir
-@@ -326,6 +327,10 @@ def create_flex_decoding_kernel(*args, **kwargs):
+@@ -326,6 +327,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
      # Set default to False
      cur_kernel_options.setdefault("USE_TMA", False)
 
-+    # Change to True if block pointer implementation is removed.
 +    if torch.xpu.is_available() and can_use_tma(query, key, value):
-+        cur_kernel_options["USE_TMA"] = False
++        cur_kernel_options["USE_TMA"] = True
 +
      # Add ROCm-specific parameters if they exist in the config
      for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
          if hasattr(conf, attrib):
 diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
-index 31c64055e35..a75792787a1 100644
+index 31c64055e3..a75792787a 100644
 --- a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
 +++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
 @@ -130,10 +130,27 @@
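
For readers skimming the patch, the following is a minimal, self-contained sketch of the option logic the updated patch produces in create_flex_decoding_kernel. The real can_use_tma helper lives in PyTorch's inductor flex-attention code and is not reproduced here; the stub below is a placeholder assumption so the snippet runs on its own, and its contiguity check is illustrative rather than the actual TMA eligibility test.

import torch

def can_use_tma(query, key, value):
    # Placeholder stub: the real inductor helper checks whether the tensors'
    # layouts are compatible with TMA (tensor descriptor) loads. A simple
    # contiguity check stands in for it here.
    return all(t.is_contiguous() for t in (query, key, value))

def choose_kernel_options(query, key, value):
    cur_kernel_options = {}
    # Upstream default, unchanged by the patch.
    cur_kernel_options.setdefault("USE_TMA", False)
    # After this commit: with the block pointer implementation removed from
    # the PyTorch pin, XPU opts into TMA whenever the inputs allow it.
    if torch.xpu.is_available() and can_use_tma(query, key, value):
        cur_kernel_options["USE_TMA"] = True
    return cur_kernel_options

q = k = v = torch.randn(1, 1, 16, 64)
print(choose_kernel_options(q, k, v))  # {'USE_TMA': False} on hosts without an XPU device

Note that before this commit the patch still forced USE_TMA to False on the same condition, with a comment saying to flip it once the block pointer implementation was removed; this commit performs exactly that flip.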
