@@ -33,7 +33,7 @@ index a0e7dce3df4d..9cd30e0178bf 100644
3333 RUN bash ./install_xpu.sh && rm install_xpu.sh
3434
3535diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
36- index e78cf68244ee..79cb9d102bdd 100644
36+ index 8e4746212a0b..31c914399fae 100644
3737--- a/test/inductor/test_flex_attention.py
3838+++ b/test/inductor/test_flex_attention.py
3939@@ -42,20 +42,26 @@
@@ -372,39 +372,39 @@ index e78cf68244ee..79cb9d102bdd 100644
372372 def test_captured_reduction(self, device, dtype):
373373 scale = torch.randn((B, 8), device=device)
374374
375- @@ -2296,6 +2364,7 @@ def f(q, k, v):
375+ @@ -2340,6 +2408,7 @@ def f(q, k, v):
376376 @supported_platform
377377 @dtypes(*device_configs["cpu"].dtypes)
378378 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
379379+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
380380 def test_njt_causal(self, device, dtype):
381381 offsets = torch.tensor(
382382 [0, 1024, 1024 + 512, S], device=device, dtype=torch.int32
383- @@ -2358,6 +2427,7 @@ def bias_mod(score, batch, head, token_q, token_kv):
383+ @@ -2402,6 +2471,7 @@ def bias_mod(score, batch, head, token_q, token_kv):
384384 @common_utils.parametrize("score_mod", test_score_mods)
385385 @dtypes(*device_configs["cpu"].dtypes)
386386 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
387387+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
388388 @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
389389 def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims):
390390 qk_d, v_d = head_dims
391- @@ -2451,6 +2521,7 @@ def causal(b, h, q_idx, kv_idx):
391+ @@ -2495,6 +2565,7 @@ def causal(b, h, q_idx, kv_idx):
392392 @common_utils.parametrize("head_dim", [17, 24, 94, 121])
393393 @dtypes(*device_configs["cpu"].dtypes_fast)
394394 @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
395395+ @dtypesIfXPU(*device_configs["xpu"].dtypes_fast)
396396 def test_non_pow_2_headdim(self, device, dtype, head_dim):
397397 self.run_test(_rel_bias, dtype, device, B, H, S, head_dim, B, H, S, head_dim)
398398
399- @@ -2515,6 +2586,7 @@ def causal_constructor(S):
399+ @@ -2559,6 +2630,7 @@ def causal_constructor(S):
400400 @skip_on_cpu
401401 @dtypes(*device_configs["cpu"].dtypes)
402402 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
403403+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
404404 @common_utils.parametrize("score_mod", [_identity, _causal])
405405 def test_logsumexp_correctness(self, device, dtype, score_mod):
406406 make_tensor = functools.partial(
407- @@ -2971,7 +3043,7 @@ def test_flex_attention_backward_stride_ordering(
407+ @@ -3015,7 +3087,7 @@ def test_flex_attention_backward_stride_ordering(
408408 def test_non_contiguous_last_dim(self, device):
409409 """Test flex_attention with tensors having non contiguous last dimension."""
410410 B, H, D = 4, 8, 64
@@ -413,7 +413,7 @@ index e78cf68244ee..79cb9d102bdd 100644
413413 for S in [16, 64]:
414414
415415 def column_major_tensor():
416- @@ -3803,7 +3875,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
416+ @@ -3847,7 +3919,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
417417
418418 class mask_graph0(torch.nn.Module):
419419 def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
@@ -422,7 +422,7 @@ index e78cf68244ee..79cb9d102bdd 100644
422422 return full_default
423423 """.replace( # noqa: B950
424424 "GPU_TYPE", torch.device(device).type
425- @@ -4091,9 +4163,9 @@ def flex_attn_fn(x):
425+ @@ -4135,9 +4207,9 @@ def flex_attn_fn(x):
426426 return output
427427
428428 flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
@@ -434,7 +434,7 @@ index e78cf68244ee..79cb9d102bdd 100644
434434
435435 # Run without compilation
436436 output_module = flex_module(x)
437- @@ -4188,12 +4260,13 @@ def make_tensor():
437+ @@ -4232,12 +4304,13 @@ def make_tensor():
438438
439439 @supported_platform
440440 @skip_on_cpu
@@ -450,15 +450,15 @@ index e78cf68244ee..79cb9d102bdd 100644
450450 dtype=torch.bfloat16,
451451 )
452452 query, key, value = make_tensor(), make_tensor(), make_tensor()
453- @@ -4777,6 +4850,7 @@ def flex_attention_fn():
453+ @@ -4821,6 +4894,7 @@ def flex_attention_fn():
454454 )
455455
456456 @supported_platform
457457+ @skip_on_xpu
458458 def test_create_is_cuda_graphable(self, device):
459459 def mask_mod(b, h, q, kv):
460460 return q >= kv
461- @@ -4958,7 +5032,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
461+ @@ -5002,7 +5076,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
462462 self.assertIsNone(cpu_mask.q_indices)
463463
464464
@@ -467,15 +467,15 @@ index e78cf68244ee..79cb9d102bdd 100644
467467 class TestPagedAttention(InductorTestCase):
468468 def setUp(self):
469469 super().setUp()
470- @@ -5273,6 +5347,7 @@ def test_update(self, device):
470+ @@ -5317,6 +5391,7 @@ def test_update(self, device):
471471 @supported_platform
472472 @dtypes(*device_configs["cpu"].dtypes)
473473 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
474474+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
475475 @common_utils.parametrize("score_mod", test_score_mods)
476476 def test_paged_builtin_score_mods(
477477 self, device, dtype: torch.dtype, score_mod: Callable
478- @@ -5401,14 +5476,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
478+ @@ -5445,14 +5520,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
479479
480480
481481 supports_learnable_bias = unittest.skipUnless(
@@ -497,7 +497,16 @@ index e78cf68244ee..79cb9d102bdd 100644
497497 class TestLearnableBiases(InductorTestCase):
498498 def setUp(self):
499499 super().setUp()
500- @@ -6299,10 +6377,22 @@ def _test_learnable_bias_inner(
500+ @@ -5505,7 +5583,7 @@ def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35):
501+ def _check_outputs_and_grads(
502+ self, out_eager, out_compiled, out_gold, tensors, names=None
503+ ):
504+ - backwards_grad = torch.randn_like(out_eager)
505+ + backwards_grad = torch.randn_like(out_eager, device="cpu").to(out_eager.device)
506+ grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad)
507+ grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad)
508+ grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad)
509+ @@ -6343,10 +6421,22 @@ def _test_learnable_bias_inner(
501510 )
502511
503512
@@ -690,14 +699,14 @@ index b5ec59dc291c..777892a0ce2d 100644
690699 if __name__ == "__main__":
691700 from torch._inductor.test_case import run_tests
692701diff --git a/third_party/xpu.txt b/third_party/xpu.txt
693- index f3cfe7166aa7..d13f6ae35d03 100644
702+ index b84ebb55a901..42d53a213bd4 100644
694703--- a/third_party/xpu.txt
695704+++ b/third_party/xpu.txt
696705@@ -1 +1 @@
697- - 3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
698- + 3f07dd52aac2e466c3c3efc15f88118f21428272
706+ - 1f7a57f50745a429b7da10dddf2e366687659b87
707+ + 2d6a5c68eca42378e0df9c92171f090eecdf5f96
699708diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
700- index 0553fd06755d..d094a48627fb 100644
709+ index b6f5646bb57c..0cc877e75ebf 100644
701710--- a/torch/_inductor/kernel/flex/flex_attention.py
702711+++ b/torch/_inductor/kernel/flex/flex_attention.py
703712@@ -531,7 +531,9 @@ def flex_attention(
@@ -711,7 +720,7 @@ index 0553fd06755d..d094a48627fb 100644
711720
712721 # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
713722 SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
714- @@ -1653,7 +1655,9 @@ def flex_attention_backward(*args, **kwargs):
723+ @@ -1655,7 +1657,9 @@ def flex_attention_backward(*args, **kwargs):
715724
716725 dtype = query.get_dtype()
717726 head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -723,7 +732,7 @@ index 0553fd06755d..d094a48627fb 100644
723732 # Default config for warp specialization
724733 num_consumer_groups, num_buffers_warp_spec = 0, 0
725734diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
726- index 83c6b59cec96..e89981286ed8 100644
735+ index 7f92fbc705a5..c5868cb21bae 100644
727736--- a/torch/_inductor/kernel/flex/flex_decoding.py
728737+++ b/torch/_inductor/kernel/flex/flex_decoding.py
729738@@ -354,7 +354,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
@@ -738,7 +747,7 @@ index 83c6b59cec96..e89981286ed8 100644
738747 bh = max(B * H, 1) # NOTE: Handle B*h=0 case
739748 assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
740749 split_k = num_SM // bh * 2 # Each SM should at least get one block.
741- @@ -458,7 +461,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
750+ @@ -459,7 +462,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
742751 choices: list[Any] = []
743752 dtype = key.get_dtype()
744753 head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
@@ -749,7 +758,7 @@ index 83c6b59cec96..e89981286ed8 100644
749758
750759 # TODO: fix autotuning.
751760
752- @@ -505,7 +510,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
761+ @@ -506,7 +511,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
753762 )
754763 * gqa_shared_heads
755764 ),
@@ -759,7 +768,7 @@ index 83c6b59cec96..e89981286ed8 100644
759768 ),
760769 )
761770diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
762- index eec1d055ddf7..f7a5aefb5cd1 100644
771+ index 57eaef9b4dbb..f5f414a68539 100644
763772--- a/torch/_inductor/template_heuristics.py
764773+++ b/torch/_inductor/template_heuristics.py
765774@@ -3,6 +3,7 @@
@@ -770,7 +779,7 @@ index eec1d055ddf7..f7a5aefb5cd1 100644
770779 from functools import partial
771780 from threading import Lock
772781 from typing import Any, Callable, Optional, TYPE_CHECKING
773- @@ -1203,6 +1204,97 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
782+ @@ -1208,6 +1209,114 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
774783 Placeholder child class for XPU specific overrides.
775784 """
776785
@@ -788,6 +797,23 @@ index eec1d055ddf7..f7a5aefb5cd1 100644
788797+ (torch.float16, 128): FlexConfig(128, 64, 1, 16),
789798+ (torch.float16, 256): FlexConfig(32, 64, 1, 4),
790799+ }
800+ + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [
801+ + FlexConfig(32, 16, 2, 4),
802+ + FlexConfig(128, 64, 2, 16),
803+ + FlexConfig(128, 64, 2, 8),
804+ + FlexConfig(128, 32, 2, 16),
805+ + FlexConfig(128, 32, 2, 8),
806+ + ]
807+ + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [
808+ + FlexDecodeConfig(32, 1, 2),
809+ + FlexDecodeConfig(32, 1, 1),
810+ + FlexDecodeConfig(32, 2, 2),
811+ + FlexDecodeConfig(32, 2, 1),
812+ + FlexDecodeConfig(64, 1, 2),
813+ + FlexDecodeConfig(64, 1, 1),
814+ + FlexDecodeConfig(64, 2, 2),
815+ + FlexDecodeConfig(64, 2, 1),
816+ + ]
791817+
792818+ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
793819+ flex_attn_fwd_configs: list[FlexConfig] = []
@@ -899,7 +925,7 @@ index ec8027595e6f..f1d290467fb5 100644
899925 "FlexAttention is only supported on CUDA, CPU or HPU devices. "
900926 f"Found input tensors on {query.device.type} device."
901927diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
902- index 01499280da8f..6a5951fde65d 100644
928+ index 528497ba5457..061c2a2eb819 100644
903929--- a/torch/testing/_internal/common_device_type.py
904930+++ b/torch/testing/_internal/common_device_type.py
905931@@ -1342,8 +1342,8 @@ def dep_fn(self, *args, **kwargs):
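For orientation, here is a minimal standalone sketch of the (dtype, head_dim) -> config lookup pattern that the template_heuristics.py hunk above introduces for XPU. Only the two float16 entries visible in the hunk are taken from the diff; the FlexConfig field names and ordering, the fallback entry, and the XPUFlexHeuristicSketch class are assumptions for illustration, not the actual torch._inductor code.

# Simplified sketch, not the real torch._inductor.template_heuristics.XPUConfigHeuristic.
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class FlexConfig:
    # Field ordering assumed to be (block_m, block_n, num_stages, num_warps).
    block_m: int
    block_n: int
    num_stages: int
    num_warps: int


class XPUFlexHeuristicSketch:
    def __init__(self) -> None:
        # Keyed on (dtype, head_dim); the two entries below appear in the hunk above.
        self.flex_attn_fwd_configs = {
            (torch.float16, 128): FlexConfig(128, 64, 1, 16),
            (torch.float16, 256): FlexConfig(32, 64, 1, 4),
        }
        # Fallback when no exact (dtype, head_dim) entry exists (assumed, not from the diff).
        self.default_config = FlexConfig(64, 32, 1, 8)

    def get_flex_attn_fwd_configs(self, head_dim: int, dtype: torch.dtype) -> list[FlexConfig]:
        # Return the matching config, or the fallback, as a single-element candidate list.
        return [self.flex_attn_fwd_configs.get((dtype, head_dim), self.default_config)]


if __name__ == "__main__":
    heuristic = XPUFlexHeuristicSketch()
    print(heuristic.get_flex_attn_fwd_configs(128, torch.float16))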