@@ -33,7 +33,7 @@ index a0e7dce3df4d..9cd30e0178bf 100644
RUN bash ./install_xpu.sh && rm install_xpu.sh

diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
- index e78cf68244ee..79cb9d102bdd 100644
+ index 8e4746212a0b..31c914399fae 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -42,20 +42,26 @@
@@ -372,39 +372,39 @@ index e78cf68244ee..79cb9d102bdd 100644
def test_captured_reduction(self, device, dtype):
scale = torch.randn((B, 8), device=device)

- @@ -2296,6 +2364,7 @@ def f(q, k, v):
+ @@ -2340,6 +2408,7 @@ def f(q, k, v):
@supported_platform
@dtypes(*device_configs["cpu"].dtypes)
@dtypesIfCUDA(*device_configs["cuda"].dtypes)
+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
def test_njt_causal(self, device, dtype):
offsets = torch.tensor(
[0, 1024, 1024 + 512, S], device=device, dtype=torch.int32
- @@ -2358,6 +2427,7 @@ def bias_mod(score, batch, head, token_q, token_kv):
+ @@ -2402,6 +2471,7 @@ def bias_mod(score, batch, head, token_q, token_kv):
@common_utils.parametrize("score_mod", test_score_mods)
@dtypes(*device_configs["cpu"].dtypes)
@dtypesIfCUDA(*device_configs["cuda"].dtypes)
+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
@common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims):
qk_d, v_d = head_dims
- @@ -2451,6 +2521,7 @@ def causal(b, h, q_idx, kv_idx):
+ @@ -2495,6 +2565,7 @@ def causal(b, h, q_idx, kv_idx):
@common_utils.parametrize("head_dim", [17, 24, 94, 121])
@dtypes(*device_configs["cpu"].dtypes_fast)
@dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+ @dtypesIfXPU(*device_configs["xpu"].dtypes_fast)
def test_non_pow_2_headdim(self, device, dtype, head_dim):
self.run_test(_rel_bias, dtype, device, B, H, S, head_dim, B, H, S, head_dim)

- @@ -2515,6 +2586,7 @@ def causal_constructor(S):
+ @@ -2559,6 +2630,7 @@ def causal_constructor(S):
@skip_on_cpu
@dtypes(*device_configs["cpu"].dtypes)
@dtypesIfCUDA(*device_configs["cuda"].dtypes)
+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
@common_utils.parametrize("score_mod", [_identity, _causal])
def test_logsumexp_correctness(self, device, dtype, score_mod):
make_tensor = functools.partial(
- @@ -2971,7 +3043,7 @@ def test_flex_attention_backward_stride_ordering(
+ @@ -3015,7 +3087,7 @@ def test_flex_attention_backward_stride_ordering(
def test_non_contiguous_last_dim(self, device):
"""Test flex_attention with tensors having non contiguous last dimension."""
B, H, D = 4, 8, 64
@@ -413,7 +413,7 @@ index e78cf68244ee..79cb9d102bdd 100644
for S in [16, 64]:

def column_major_tensor():
- @@ -3803,7 +3875,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
+ @@ -3847,7 +3919,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3

class mask_graph0(torch.nn.Module):
def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
@@ -422,7 +422,7 @@ index e78cf68244ee..79cb9d102bdd 100644
return full_default
""".replace( # noqa: B950
"GPU_TYPE", torch.device(device).type
- @@ -4091,9 +4163,9 @@ def flex_attn_fn(x):
+ @@ -4135,9 +4207,9 @@ def flex_attn_fn(x):
return output

flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
@@ -434,7 +434,7 @@ index e78cf68244ee..79cb9d102bdd 100644

# Run without compilation
output_module = flex_module(x)
- @@ -4188,12 +4260,13 @@ def make_tensor():
+ @@ -4232,12 +4304,13 @@ def make_tensor():

@supported_platform
@skip_on_cpu
@@ -450,15 +450,15 @@ index e78cf68244ee..79cb9d102bdd 100644
dtype=torch.bfloat16,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()
- @@ -4777,6 +4850,7 @@ def flex_attention_fn():
+ @@ -4821,6 +4894,7 @@ def flex_attention_fn():
)

@supported_platform
+ @skip_on_xpu
def test_create_is_cuda_graphable(self, device):
def mask_mod(b, h, q, kv):
return q >= kv
- @@ -4958,7 +5032,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
+ @@ -5002,7 +5076,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
self.assertIsNone(cpu_mask.q_indices)
@@ -467,15 +467,15 @@ index e78cf68244ee..79cb9d102bdd 100644
class TestPagedAttention(InductorTestCase):
def setUp(self):
super().setUp()
- @@ -5273,6 +5347,7 @@ def test_update(self, device):
+ @@ -5317,6 +5391,7 @@ def test_update(self, device):
@supported_platform
@dtypes(*device_configs["cpu"].dtypes)
@dtypesIfCUDA(*device_configs["cuda"].dtypes)
+ @dtypesIfXPU(*device_configs["xpu"].dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_paged_builtin_score_mods(
self, device, dtype: torch.dtype, score_mod: Callable
- @@ -5401,14 +5476,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
+ @@ -5445,14 +5520,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:


supports_learnable_bias = unittest.skipUnless(
@@ -497,7 +497,16 @@ index e78cf68244ee..79cb9d102bdd 100644
class TestLearnableBiases(InductorTestCase):
def setUp(self):
super().setUp()
- @@ -6299,10 +6377,22 @@ def _test_learnable_bias_inner(
+ @@ -5505,7 +5583,7 @@ def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35):
+ def _check_outputs_and_grads(
+ self, out_eager, out_compiled, out_gold, tensors, names=None
+ ):
+ - backwards_grad = torch.randn_like(out_eager)
+ + backwards_grad = torch.randn_like(out_eager, device="cpu").to(out_eager.device)
+ grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad)
+ grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad)
+ grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad)
+ @@ -6343,10 +6421,22 @@ def _test_learnable_bias_inner(
)
@@ -690,14 +699,14 @@ index b5ec59dc291c..777892a0ce2d 100644
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
diff --git a/third_party/xpu.txt b/third_party/xpu.txt
- index f3cfe7166aa7..d13f6ae35d03 100644
+ index b84ebb55a901..42d53a213bd4 100644
--- a/third_party/xpu.txt
+++ b/third_party/xpu.txt
@@ -1 +1 @@
- - 3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
- + 3f07dd52aac2e466c3c3efc15f88118f21428272
+ - 1f7a57f50745a429b7da10dddf2e366687659b87
+ + 2d6a5c68eca42378e0df9c92171f090eecdf5f96
diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
- index 0553fd06755d..d094a48627fb 100644
+ index b6f5646bb57c..0cc877e75ebf 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -531,7 +531,9 @@ def flex_attention(
@@ -711,7 +720,7 @@ index 0553fd06755d..d094a48627fb 100644

# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
- @@ -1653,7 +1655,9 @@ def flex_attention_backward(*args, **kwargs):
+ @@ -1655,7 +1657,9 @@ def flex_attention_backward(*args, **kwargs):

dtype = query.get_dtype()
head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -723,7 +732,7 @@ index 0553fd06755d..d094a48627fb 100644
# Default config for warp specialization
num_consumer_groups, num_buffers_warp_spec = 0, 0
diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
- index 83c6b59cec96..e89981286ed8 100644
+ index 7f92fbc705a5..c5868cb21bae 100644
--- a/torch/_inductor/kernel/flex/flex_decoding.py
+++ b/torch/_inductor/kernel/flex/flex_decoding.py
@@ -354,7 +354,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
@@ -738,7 +747,7 @@ index 83c6b59cec96..e89981286ed8 100644
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
split_k = num_SM // bh * 2 # Each SM should at least get one block.
- @@ -458,7 +461,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
+ @@ -459,7 +462,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
choices: list[Any] = []
dtype = key.get_dtype()
head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
@@ -749,7 +758,7 @@ index 83c6b59cec96..e89981286ed8 100644

# TODO: fix autotuning.

- @@ -505,7 +510,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
+ @@ -506,7 +511,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
)
* gqa_shared_heads
),
@@ -759,7 +768,7 @@ index 83c6b59cec96..e89981286ed8 100644
),
)
diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
- index eec1d055ddf7..f7a5aefb5cd1 100644
+ index 57eaef9b4dbb..f5f414a68539 100644
--- a/torch/_inductor/template_heuristics.py
+++ b/torch/_inductor/template_heuristics.py
@@ -3,6 +3,7 @@
@@ -770,7 +779,7 @@ index eec1d055ddf7..f7a5aefb5cd1 100644
from functools import partial
from threading import Lock
from typing import Any, Callable, Optional, TYPE_CHECKING
- @@ -1203,6 +1204,97 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
+ @@ -1208,6 +1209,114 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
Placeholder child class for XPU specific overrides.
"""
@@ -788,6 +797,23 @@ index eec1d055ddf7..f7a5aefb5cd1 100644
+ (torch.float16, 128): FlexConfig(128, 64, 1, 16),
+ (torch.float16, 256): FlexConfig(32, 64, 1, 4),
+ }
+ + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [
+ + FlexConfig(32, 16, 2, 4),
+ + FlexConfig(128, 64, 2, 16),
+ + FlexConfig(128, 64, 2, 8),
+ + FlexConfig(128, 32, 2, 16),
+ + FlexConfig(128, 32, 2, 8),
+ + ]
+ + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [
+ + FlexDecodeConfig(32, 1, 2),
+ + FlexDecodeConfig(32, 1, 1),
+ + FlexDecodeConfig(32, 2, 2),
+ + FlexDecodeConfig(32, 2, 1),
+ + FlexDecodeConfig(64, 1, 2),
+ + FlexDecodeConfig(64, 1, 1),
+ + FlexDecodeConfig(64, 2, 2),
+ + FlexDecodeConfig(64, 2, 1),
+ + ]
+
+ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
+ flex_attn_fwd_configs: list[FlexConfig] = []
@@ -899,7 +925,7 @@ index ec8027595e6f..f1d290467fb5 100644
"FlexAttention is only supported on CUDA, CPU or HPU devices. "
f"Found input tensors on {query.device.type} device."
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
- index 01499280da8f..6a5951fde65d 100644
+ index 528497ba5457..061c2a2eb819 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -1342,8 +1342,8 @@ def dep_fn(self, *args, **kwargs):