@@ -1,5 +1,5 @@
 diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
-index 91ba941da06..2eba71224c6 100644
+index 91ba941da0..a6b87212ad 100644
 --- a/torch/_inductor/kernel/flex/flex_decoding.py
 +++ b/torch/_inductor/kernel/flex/flex_decoding.py
 @@ -6,6 +6,7 @@ from typing import Any
@@ -10,19 +10,18 @@ index 91ba941da06..2eba71224c6 100644
  from torch._inductor.virtualized import V
 
  from ... import ir
-@@ -326,6 +327,10 @@ def create_flex_decoding_kernel(*args, **kwargs):
+@@ -326,6 +327,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
          # Set default to False
          cur_kernel_options.setdefault("USE_TMA", False)
 
-+        # Change to True if block pointer implementation is removed.
 +        if torch.xpu.is_available() and can_use_tma(query, key, value):
-+            cur_kernel_options["USE_TMA"] = False
++            cur_kernel_options["USE_TMA"] = True
 +
          # Add ROCm-specific parameters if they exist in the config
          for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
              if hasattr(conf, attrib):
 diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
-index 31c64055e35..a75792787a1 100644
+index 31c64055e3..a75792787a1 100644
 --- a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
 +++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja
 @@ -130,10 +130,27 @@
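
For reference, a minimal runnable sketch of the USE_TMA gating this change converges on. `torch.xpu.is_available()` is the real PyTorch API, but `can_use_tma` below is a simplified stand-in for Inductor's internal helper (assumed here to check innermost-stride contiguity and base-pointer alignment), and `build_kernel_options` is a hypothetical wrapper used only for illustration, not a function from the patch.

```python
import torch

def can_use_tma(*tensors: torch.Tensor, alignment: int = 16) -> bool:
    # Stand-in for Inductor's internal check: TMA descriptors generally
    # require a contiguous innermost dimension and an aligned base address.
    return all(
        t.stride(-1) == 1
        and (t.storage_offset() * t.element_size()) % alignment == 0
        for t in tensors
    )

def build_kernel_options(query, key, value, kernel_options=None):
    # Hypothetical wrapper mirroring the patched logic in
    # create_flex_decoding_kernel: default off, opt in on qualifying XPU inputs.
    cur_kernel_options = dict(kernel_options or {})
    cur_kernel_options.setdefault("USE_TMA", False)
    if torch.xpu.is_available() and can_use_tma(query, key, value):
        cur_kernel_options["USE_TMA"] = True
    return cur_kernel_options

# Example: q/k/v in a typical decode shape [batch, heads, seq, head_dim].
q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
print(build_kernel_options(q, k, v))
```

Defaulting USE_TMA to False and flipping it on only after an explicit capability check keeps the block-pointer implementation as the safe fallback on devices or layouts where TMA descriptors are unavailable.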