@@ -525,7 +525,7 @@ index e78cf68244ee..79cb9d102bdd 100644
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
- index 5af7b284f757..667915c69507 100644
+ index b5ec59dc291c..777892a0ce2d 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -27,6 +27,7 @@
@@ -598,7 +598,7 @@ index 5af7b284f757..667915c69507 100644
requires_grad=True,
)
q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
- @@ -998,12 +1007,12 @@ def mask_mod(b, h, q, kv):
+ @@ -999,12 +1008,12 @@ def mask_mod(b, h, q, kv):

@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@@ -615,7 +615,7 @@ index 5af7b284f757..667915c69507 100644

def score_mod(score, b, h, q, kv):
return score + offset_kv[kv] + offset_q[q]
- @@ -1011,8 +1020,14 @@ def score_mod(score, b, h, q, kv):
+ @@ -1012,8 +1021,14 @@ def score_mod(score, b, h, q, kv):
def mask_mod(b, h, q, kv):
return kv >= q + offset_tensor

@@ -632,7 +632,7 @@ index 5af7b284f757..667915c69507 100644

@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
- @@ -1677,19 +1692,19 @@ def mask_mod(b, h, q, kv):
+ @@ -1679,19 +1694,19 @@ def mask_mod(b, h, q, kv):
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", [_identity, _causal])
@@ -655,7 +655,7 @@ index 5af7b284f757..667915c69507 100644
requires_grad=True,
)
q, k, v = make_q(), make_kv(), make_kv()
- @@ -1729,19 +1744,19 @@ def eager_sdpa_hop(q, k, v, score_mod):
+ @@ -1731,19 +1746,19 @@ def eager_sdpa_hop(q, k, v, score_mod):

@supported_platform
@unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@@ -678,7 +678,7 @@ index 5af7b284f757..667915c69507 100644
requires_grad=True,
)

- @@ -1993,7 +2008,9 @@ def causal_mask(b, h, q, kv):
+ @@ -1995,7 +2010,9 @@ def causal_mask(b, h, q, kv):
self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")

@@ -689,11 +689,18 @@ index 5af7b284f757..667915c69507 100644

if __name__ == "__main__":
from torch._inductor.test_case import run_tests
- diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
- index e471332afe71..ced92fae6229 100644
- --- a/torch/_inductor/kernel/flex_attention.py
- +++ b/torch/_inductor/kernel/flex_attention.py
- @@ -1445,7 +1445,9 @@ def flex_attention(
+ diff --git a/third_party/xpu.txt b/third_party/xpu.txt
+ index f3cfe7166aa7..d13f6ae35d03 100644
+ --- a/third_party/xpu.txt
+ +++ b/third_party/xpu.txt
+ @@ -1 +1 @@
+ - 3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
+ + 3f07dd52aac2e466c3c3efc15f88118f21428272
+ diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
+ index 0553fd06755d..d094a48627fb 100644
+ --- a/torch/_inductor/kernel/flex/flex_attention.py
+ +++ b/torch/_inductor/kernel/flex/flex_attention.py
+ @@ -531,7 +531,9 @@ def flex_attention(

dtype = query.get_dtype()
head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -704,7 +711,7 @@ index e471332afe71..ced92fae6229 100644

# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
- @@ -2567,7 +2569,9 @@ def flex_attention_backward(*args, **kwargs):
+ @@ -1653,7 +1655,9 @@ def flex_attention_backward(*args, **kwargs):

dtype = query.get_dtype()
head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -715,11 +722,11 @@ index e471332afe71..ced92fae6229 100644

# Default config for warp specialization
num_consumer_groups, num_buffers_warp_spec = 0, 0
- diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
- index 7e0aef981856..628bfc6419be 100644
- --- a/torch/_inductor/kernel/flex_decoding.py
- +++ b/torch/_inductor/kernel/flex_decoding.py
- @@ -310,7 +310,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
+ diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
+ index 83c6b59cec96..e89981286ed8 100644
+ --- a/torch/_inductor/kernel/flex/flex_decoding.py
+ +++ b/torch/_inductor/kernel/flex/flex_decoding.py
+ @@ -354,7 +354,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me


def get_split_k(B: int, H: int, Mk: int) -> int:
@@ -731,7 +738,7 @@ index 7e0aef981856..628bfc6419be 100644
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
split_k = num_SM // bh * 2 # Each SM should at least get one block.
- @@ -415,7 +418,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
+ @@ -458,7 +461,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
choices: list[Any] = []
dtype = key.get_dtype()
head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
@@ -742,24 +749,31 @@ index 7e0aef981856..628bfc6419be 100644

# TODO: fix autotuning.

- @@ -462,7 +467,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
+ @@ -505,7 +510,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
)
* gqa_shared_heads
),
- 16,
- + float('-inf') if torch.xpu.is_available() else 16,
+ + 1 if torch.xpu.is_available() else 16,
)
),
)
diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
- index 65a6851192a0..3a53f0eed52e 100644
+ index eec1d055ddf7..f7a5aefb5cd1 100644
--- a/torch/_inductor/template_heuristics.py
+++ b/torch/_inductor/template_heuristics.py
- @@ -1201,3 +1201,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
- """
+ @@ -3,6 +3,7 @@
+ import dataclasses
+ import itertools
+ import math
+ + import os
+ from functools import partial
+ from threading import Lock
+ from typing import Any, Callable, Optional, TYPE_CHECKING
+ @@ -1203,6 +1204,97 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
Placeholder child class for XPU specific overrides.
"""
- +
+
+ def __init__(self) -> None:
+ super().__init__()
+
@@ -804,6 +818,9 @@ index 65a6851192a0..3a53f0eed52e 100644
+
+ def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
+ flex_attn_bwd_configs: list[FlexConfig] = []
+ + TRITON_LESS_FLEX_ATTN_BWD_CONFIGS = os.getenv(
+ + "TRITON_LESS_FLEX_ATTN_BWD_CONFIGS", "0"
+ + ).lower() in {"true", "1", "t", "y", "yes", "on"}
+
+ if config.max_autotune:
+ if config.max_autotune_flex_search_space == "EXHAUSTIVE":
@@ -825,6 +842,10 @@ index 65a6851192a0..3a53f0eed52e 100644
+ if default_config not in flex_attn_bwd_configs:
+ flex_attn_bwd_configs.append(default_config)
+
+ + if TRITON_LESS_FLEX_ATTN_BWD_CONFIGS:
+ + flex_attn_bwd_configs = list(
+ + filter(lambda c: c.num_stages == 1, flex_attn_bwd_configs)
+ + )
+ return flex_attn_bwd_configs
+
+ def get_flex_decode_configs(
@@ -843,8 +864,12 @@ index 65a6851192a0..3a53f0eed52e 100644
+ flex_decode_configs.append(default_config)
+
+ return flex_decode_configs
+ +
+
+ class MTIAConfigHeuristic(BaseConfigHeuristic):
+ """
diff --git a/torch/_ops.py b/torch/_ops.py
- index fecfebaeaa53..8fac24a8579c 100644
+ index 83a5dc0e57a5..b351aa17dfa7 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -267,6 +267,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]