 diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh
-index 51e9df623d5d1a..647e77f6d17bdc 100644
+index ecbbb8ccccf897..6349a7c6829c77 100644
 --- a/.ci/docker/common/install_xpu.sh
 +++ b/.ci/docker/common/install_xpu.sh
 @@ -35,12 +35,12 @@ function install_ubuntu() {
@@ -33,7 +33,7 @@ index a0e7dce3df4d55..9cd30e0178bf92 100644
  RUN bash ./install_xpu.sh && rm install_xpu.sh
 
 diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
-index 4d14555800c8c4..ab40d19d2ff5ee 100644
+index fa6400dd9c2724..b72d8e9021fc58 100644
 --- a/test/inductor/test_flex_attention.py
 +++ b/test/inductor/test_flex_attention.py
 @@ -41,20 +41,26 @@
@@ -450,32 +450,32 @@ index 4d14555800c8c4..ab40d19d2ff5ee 100644
  dtype=torch.bfloat16,
  )
  query, key, value = make_tensor(), make_tensor(), make_tensor()
-@@ -4730,6 +4803,7 @@ def flex_attention_fn():
+@@ -4722,6 +4795,7 @@ def flex_attention_fn():
  )
 
  @supported_platform
 + @skip_on_xpu
  def test_create_is_cuda_graphable(self, device):
  def mask_mod(b, h, q, kv):
  return q >= kv
-@@ -4771,7 +4845,7 @@ def create_inputs(S):
- flex_attention_call(*create_inputs(1024), block_mask=block_mask)
+@@ -4903,7 +4977,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
+ self.assertIsNone(cpu_mask.q_indices)
 
 
 - @large_tensor_test_class("2GB", device="cuda")
 + @large_tensor_test_class("2GB", device=test_device[0])
  class TestPagedAttention(InductorTestCase):
  def setUp(self):
  super().setUp()
-@@ -5086,6 +5160,7 @@ def test_update(self, device):
+@@ -5218,6 +5292,7 @@ def test_update(self, device):
  @supported_platform
  @dtypes(*device_configs["cpu"].dtypes)
  @dtypesIfCUDA(*device_configs["cuda"].dtypes)
 + @dtypesIfXPU(*device_configs["xpu"].dtypes)
  @common_utils.parametrize("score_mod", test_score_mods)
  def test_paged_builtin_score_mods(
  self, device, dtype: torch.dtype, score_mod: Callable
-@@ -5214,14 +5289,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
+@@ -5346,14 +5421,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
 
 
  supports_learnable_bias = unittest.skipUnless(
@@ -497,7 +497,7 @@ index 4d14555800c8c4..ab40d19d2ff5ee 100644
  class TestLearnableBiases(InductorTestCase):
  def setUp(self):
  super().setUp()
-@@ -6112,10 +6190,22 @@ def _test_learnable_bias_inner(
+@@ -6244,10 +6322,22 @@ def _test_learnable_bias_inner(
  )
 
 
@@ -691,7 +691,7 @@ index 3b4905fc356168..3a165e5fff2eda 100644
  if __name__ == "__main__":
  from torch._inductor.test_case import run_tests
 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
-index 99e869dc8fdb71..1426f000d191d2 100644
+index 9a7507631cc490..1a761bf3833e56 100644
 --- a/torch/_inductor/kernel/flex_attention.py
 +++ b/torch/_inductor/kernel/flex_attention.py
 @@ -1441,7 +1441,9 @@ def flex_attention(
@@ -705,7 +705,7 @@ index 99e869dc8fdb71..1426f000d191d2 100644
 
  # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
  SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
-@@ -2557,7 +2559,9 @@ def flex_attention_backward(*args, **kwargs):
+@@ -2560,7 +2562,9 @@ def flex_attention_backward(*args, **kwargs):
 
  dtype = query.get_dtype()
  head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -717,7 +717,7 @@ index 99e869dc8fdb71..1426f000d191d2 100644
  # Default config for warp specialization
  num_consumer_groups, num_buffers_warp_spec = 0, 0
 diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
-index 7e0aef98185603..343086e5b2d16a 100644
+index 7e0aef98185603..6c8bd4b593ae38 100644
 --- a/torch/_inductor/kernel/flex_decoding.py
 +++ b/torch/_inductor/kernel/flex_decoding.py
 @@ -310,7 +310,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
@@ -743,11 +743,71 @@ index 7e0aef98185603..343086e5b2d16a 100644
 
  # TODO: fix autotuning.
 
+@@ -448,24 +453,41 @@ def create_flex_decoding_kernel(*args, **kwargs):
+
+ set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)
+
+- kernel_options.setdefault(
+- "BLOCK_M",
+- (
+- # m
+- # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
+- # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
+- max(
+- next_power_of_2(
+- V.graph.sizevars.size_hint(
+- seq_len_q,
+- fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
+- )
+- * gqa_shared_heads
+- ),
+- 16,
+- )
+- ),
+- )
++ if torch.xpu.is_available():
++ kernel_options.setdefault(
++ "BLOCK_M",
++ (
++ max(
++ next_power_of_2(
++ V.graph.sizevars.size_hint(
++ seq_len_q,
++ fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
++ )
++ * gqa_shared_heads
++ ),
++ 8,
++ )
++ ),
++ )
++ else:
++ kernel_options.setdefault(
++ "BLOCK_M",
++ (
++ # m
++ # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
++ # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
++ max(
++ next_power_of_2(
++ V.graph.sizevars.size_hint(
++ seq_len_q,
++ fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
++ )
++ * gqa_shared_heads
++ ),
++ 16,
++ )
++ ),
++ )
+
+ query = ir.ExternKernel.realize_input(query)
+ stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()
 diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
-index dfd37523a37027..b830ec6369a9d7 100644
+index 40a9645186792f..eaa6fbeaf0d4ea 100644
 --- a/torch/_inductor/template_heuristics.py
 +++ b/torch/_inductor/template_heuristics.py
-@@ -1178,3 +1178,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
+@@ -1185,3 +1185,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
  """
  Placeholder child class for XPU specific overrides.
  """
@@ -836,7 +896,7 @@ index dfd37523a37027..b830ec6369a9d7 100644
 +
 + return flex_decode_configs
 diff --git a/torch/_ops.py b/torch/_ops.py
-index 337b9a11e6a180..5e3423285e02b5 100644
+index 600f6d9e1ada1c..1121ced7eaa5ff 100644
 --- a/torch/_ops.py
 +++ b/torch/_ops.py
 @@ -267,6 +267,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
@@ -848,10 +908,10 @@ index 337b9a11e6a180..5e3423285e02b5 100644
 
 
 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
-index 15a00e1a9d342b..3c9b3a20173997 100644
+index ce592c1ed342f8..bcc180184d9aa4 100644
 --- a/torch/nn/attention/flex_attention.py
 +++ b/torch/nn/attention/flex_attention.py
-@@ -1142,11 +1142,8 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
+@@ -1146,11 +1146,8 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
  """TODO: Remove once non cuda/cpu devices support is added
  We only need to check query since we have already that q,k,v are on the same device
  """
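For context, not part of the patch: a minimal standalone sketch of the BLOCK_M selection that the flex_decoding.py hunk above switches on torch.xpu.is_available(). The helper next_power_of_2 and the seq_len_q_hint value are stand-ins for the inductor internals (next_power_of_2 and V.graph.sizevars.size_hint) used in the real code.

import torch


def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (stand-in for the helper used in flex_decoding.py).
    return 1 if n <= 1 else 2 ** (n - 1).bit_length()


def choose_block_m(seq_len_q_hint: int, gqa_shared_heads: int) -> int:
    # On XPU the patch lowers the BLOCK_M floor to 8; other devices keep the
    # original floor of 16 (kept until the Triton fix referenced in the patch
    # comment lands in the pin).
    floor = 8 if torch.xpu.is_available() else 16
    return max(next_power_of_2(seq_len_q_hint * gqa_shared_heads), floor)


# Example: a single decode token with 4 shared GQA heads.
print(choose_block_m(seq_len_q_hint=1, gqa_shared_heads=4))  # 8 on XPU, 16 elsewhere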