|
diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
index 39c8f737c7..5a8df2c3f9 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -311,6 +311,9 @@ def flex_attention(
         # USE TMA = false by default
         cur_kernel_options.setdefault("USE_TMA", False)
 
+        if torch.xpu.is_available() and can_use_tma(query, key, value):
+            cur_kernel_options["USE_TMA"] = True
+
         if cur_kernel_options["USE_TMA"] and can_use_tma(query, key, value):
             cur_kernel_options["USE_TMA"] = True
 
diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
index 91ba941da0..a6b87212ad 100644
--- a/torch/_inductor/kernel/flex/flex_decoding.py
+++ b/torch/_inductor/kernel/flex/flex_decoding.py
@@ -6,6 +6,7 @@ from typing import Any
 import sympy
 
 import torch
+from torch._inductor.utils import can_use_tma
 from torch._inductor.virtualized import V
 
 from ... import ir
@@ -326,6 +327,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
         # Set default to False
         cur_kernel_options.setdefault("USE_TMA", False)
 
+        if torch.xpu.is_available() and can_use_tma(query, key, value):
+            cur_kernel_options["USE_TMA"] = True
+
         # Add ROCm-specific parameters if they exist in the config
         for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]:
             if hasattr(conf, attrib):
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 0876f99307..4fa1c87560 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1696,7 +1696,7 @@ def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool:
         strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
 
         # Every logical size ≥ 2
-        if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
+        if not torch.xpu.is_available() and any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
             return False
 
         # Find the single contiguous (“inner”) dim