@@ -525,7 +525,7 @@ index e78cf68244ee..79cb9d102bdd 100644
525525 if __name__ == "__main__":
526526 from torch._inductor.test_case import run_tests
527527diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
528- index 5af7b284f757..667915c69507 100644
528+ index b5ec59dc291c..777892a0ce2d 100644
529529--- a/test/inductor/test_flex_decoding.py
530530+++ b/test/inductor/test_flex_decoding.py
531531@@ -27,6 +27,7 @@
@@ -598,7 +598,7 @@ index 5af7b284f757..667915c69507 100644
598598 requires_grad=True,
599599 )
600600 q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
601- @@ -998,12 +1007,12 @@ def mask_mod(b, h, q, kv):
601+ @@ -999,12 +1008,12 @@ def mask_mod(b, h, q, kv):
602602
603603 @supported_platform
604604 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@@ -615,7 +615,7 @@ index 5af7b284f757..667915c69507 100644
615615
616616 def score_mod(score, b, h, q, kv):
617617 return score + offset_kv[kv] + offset_q[q]
618- @@ -1011,8 +1020,14 @@ def score_mod(score, b, h, q, kv):
618+ @@ -1012,8 +1021,14 @@ def score_mod(score, b, h, q, kv):
619619 def mask_mod(b, h, q, kv):
620620 return kv >= q + offset_tensor
621621
@@ -632,7 +632,7 @@ index 5af7b284f757..667915c69507 100644
632632
633633 @supported_platform
634634 @common_utils.parametrize("dtype", test_dtypes_fast)
635- @@ -1677,19 +1692,19 @@ def mask_mod(b, h, q, kv):
635+ @@ -1679,19 +1694,19 @@ def mask_mod(b, h, q, kv):
636636 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
637637 @common_utils.parametrize("dtype", test_dtypes)
638638 @common_utils.parametrize("score_mod", [_identity, _causal])
@@ -655,7 +655,7 @@ index 5af7b284f757..667915c69507 100644
655655 requires_grad=True,
656656 )
657657 q, k, v = make_q(), make_kv(), make_kv()
658- @@ -1729,19 +1744,19 @@ def eager_sdpa_hop(q, k, v, score_mod):
658+ @@ -1731,19 +1746,19 @@ def eager_sdpa_hop(q, k, v, score_mod):
659659
660660 @supported_platform
661661 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@@ -678,7 +678,7 @@ index 5af7b284f757..667915c69507 100644
678678 requires_grad=True,
679679 )
680680
681- @@ -1993,7 +2008,9 @@ def causal_mask(b, h, q, kv):
681+ @@ -1995,7 +2010,9 @@ def causal_mask(b, h, q, kv):
682682 self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")
683683
684684
@@ -689,11 +689,18 @@ index 5af7b284f757..667915c69507 100644
689689
690690 if __name__ == "__main__":
691691 from torch._inductor.test_case import run_tests
692- diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
693- index e471332afe71..ced92fae6229 100644
694- --- a/torch/_inductor/kernel/flex_attention.py
695- +++ b/torch/_inductor/kernel/flex_attention.py
696- @@ -1445,7 +1445,9 @@ def flex_attention(
692+ diff --git a/third_party/xpu.txt b/third_party/xpu.txt
693+ index f3cfe7166aa7..d13f6ae35d03 100644
694+ --- a/third_party/xpu.txt
695+ +++ b/third_party/xpu.txt
696+ @@ -1 +1 @@
697+ - 3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
698+ + 3f07dd52aac2e466c3c3efc15f88118f21428272
699+ diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
700+ index 0553fd06755d..d094a48627fb 100644
701+ --- a/torch/_inductor/kernel/flex/flex_attention.py
702+ +++ b/torch/_inductor/kernel/flex/flex_attention.py
703+ @@ -531,7 +531,9 @@ def flex_attention(
697704
698705 dtype = query.get_dtype()
699706 head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -704,7 +711,7 @@ index e471332afe71..ced92fae6229 100644
704711
705712 # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
706713 SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
707- @@ -2567,7 +2569,9 @@ def flex_attention_backward(*args, **kwargs):
714+ @@ -1653,7 +1655,9 @@ def flex_attention_backward(*args, **kwargs):
708715
709716 dtype = query.get_dtype()
710717 head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -715,11 +722,11 @@ index e471332afe71..ced92fae6229 100644
715722
716723 # Default config for warp specialization
717724 num_consumer_groups, num_buffers_warp_spec = 0, 0
718- diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
719- index 7e0aef981856..628bfc6419be 100644
720- --- a/torch/_inductor/kernel/flex_decoding.py
721- +++ b/torch/_inductor/kernel/flex_decoding.py
722- @@ -310,7 +310,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
725+ diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
726+ index 83c6b59cec96..e89981286ed8 100644
727+ --- a/torch/_inductor/kernel/flex/flex_decoding.py
728+ +++ b/torch/_inductor/kernel/flex/flex_decoding.py
729+ @@ -354,7 +354,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
723730
724731
725732 def get_split_k(B: int, H: int, Mk: int) -> int:
@@ -731,7 +738,7 @@ index 7e0aef981856..628bfc6419be 100644
731738 bh = max(B * H, 1) # NOTE: Handle B*h=0 case
732739 assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
733740 split_k = num_SM // bh * 2 # Each SM should at least get one block.
734- @@ -415,7 +418,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
741+ @@ -458,7 +461,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
735742 choices: list[Any] = []
736743 dtype = key.get_dtype()
737744 head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
@@ -742,24 +749,31 @@ index 7e0aef981856..628bfc6419be 100644
742749
743750 # TODO: fix autotuning.
744751
745- @@ -462,7 +467,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
752+ @@ -505,7 +510,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
746753 )
747754 * gqa_shared_heads
748755 ),
749756- 16,
750- + float('-inf') if torch.xpu.is_available() else 16,
757+ + 1 if torch.xpu.is_available() else 16,
751758 )
752759 ),
753760 )
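The hunk above replaces the earlier `float('-inf') if torch.xpu.is_available() else 16` with `1 if torch.xpu.is_available() else 16`. A minimal sketch of the idea, assuming the value feeds a `max(...)` that bounds a kv-split block size in create_flex_decoding_kernel; the surrounding call and variable names below are illustrative, not the actual Inductor code:

import torch

def kv_split_floor() -> int:
    # Assumed simplification: the updated patch lowers the minimum kv-split
    # block size to 1 on XPU instead of disabling the floor with float('-inf'),
    # so max(..., floor) always stays a positive integer.
    return 1 if torch.xpu.is_available() else 16

# Illustrative use: the floor caps how small the per-split kv block may get.
candidate_block = 8  # e.g. derived from seq_len_kv, split_k and gqa_shared_heads
block_n = max(candidate_block, kv_split_floor())
print(block_n)  # 8 on XPU (floor of 1), 16 elsewhere (floor of 16)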
754761diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
755- index 65a6851192a0..3a53f0eed52e 100644
762+ index eec1d055ddf7..f7a5aefb5cd1 100644
756763--- a/torch/_inductor/template_heuristics.py
757764+++ b/torch/_inductor/template_heuristics.py
758- @@ -1201,3 +1201,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
759- """
765+ @@ -3,6 +3,7 @@
766+ import dataclasses
767+ import itertools
768+ import math
769+ + import os
770+ from functools import partial
771+ from threading import Lock
772+ from typing import Any, Callable, Optional, TYPE_CHECKING
773+ @@ -1203,6 +1204,97 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
760774 Placeholder child class for XPU specific overrides.
761775 """
762- +
776+
763777+ def __init__(self) -> None:
764778+ super().__init__()
765779+
@@ -804,6 +818,9 @@ index 65a6851192a0..3a53f0eed52e 100644
804818+
805819+ def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
806820+ flex_attn_bwd_configs: list[FlexConfig] = []
821+ + TRITON_LESS_FLEX_ATTN_BWD_CONFIGS = os.getenv(
822+ + "TRITON_LESS_FLEX_ATTN_BWD_CONFIGS", "0"
823+ + ).lower() in {"true", "1", "t", "y", "yes", "on"}
807824+
808825+ if config.max_autotune:
809826+ if config.max_autotune_flex_search_space == "EXHAUSTIVE":
@@ -825,6 +842,10 @@ index 65a6851192a0..3a53f0eed52e 100644
825842+ if default_config not in flex_attn_bwd_configs:
826843+ flex_attn_bwd_configs.append(default_config)
827844+
845+ + if TRITON_LESS_FLEX_ATTN_BWD_CONFIGS:
846+ + flex_attn_bwd_configs = list(
847+ + filter(lambda c: c.num_stages == 1, flex_attn_bwd_configs)
848+ + )
828849+ return flex_attn_bwd_configs
829850+
830851+ def get_flex_decode_configs(
@@ -843,8 +864,12 @@ index 65a6851192a0..3a53f0eed52e 100644
843864+ flex_decode_configs.append(default_config)
844865+
845866+ return flex_decode_configs
867+ +
868+
869+ class MTIAConfigHeuristic(BaseConfigHeuristic):
870+ """
846871diff --git a/torch/_ops.py b/torch/_ops.py
847- index fecfebaeaa53..8fac24a8579c 100644
872+ index 83a5dc0e57a5..b351aa17dfa7 100644
848873--- a/torch/_ops.py
849874+++ b/torch/_ops.py
850875@@ -267,6 +267,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]