
Commit 3caf1e9

Update PyTorch pin (#4859)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent d8e39c3 · commit 3caf1e9

File tree

2 files changed: +52 -26 lines changed


.github/pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-83e2ea8135c42fa826c3d751a04f60259e97147f
+3f1636ebef9b45e8a3cb0eb20d327ee6acb74be0

scripts/patch/flex_attn_143553.patch

Lines changed: 51 additions & 25 deletions
@@ -33,7 +33,7 @@ index a0e7dce3df4d..9cd30e0178bf 100644
 RUN bash ./install_xpu.sh && rm install_xpu.sh
 
 diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
-index e78cf68244ee..79cb9d102bdd 100644
+index 8e4746212a0b..31c914399fae 100644
 --- a/test/inductor/test_flex_attention.py
 +++ b/test/inductor/test_flex_attention.py
 @@ -42,20 +42,26 @@
@@ -372,39 +372,39 @@ index e78cf68244ee..79cb9d102bdd 100644
 def test_captured_reduction(self, device, dtype):
 scale = torch.randn((B, 8), device=device)
 
-@@ -2296,6 +2364,7 @@ def f(q, k, v):
+@@ -2340,6 +2408,7 @@ def f(q, k, v):
 @supported_platform
 @dtypes(*device_configs["cpu"].dtypes)
 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
 + @dtypesIfXPU(*device_configs["xpu"].dtypes)
 def test_njt_causal(self, device, dtype):
 offsets = torch.tensor(
 [0, 1024, 1024 + 512, S], device=device, dtype=torch.int32
-@@ -2358,6 +2427,7 @@ def bias_mod(score, batch, head, token_q, token_kv):
+@@ -2402,6 +2471,7 @@ def bias_mod(score, batch, head, token_q, token_kv):
 @common_utils.parametrize("score_mod", test_score_mods)
 @dtypes(*device_configs["cpu"].dtypes)
 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
 + @dtypesIfXPU(*device_configs["xpu"].dtypes)
 @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
 def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims):
 qk_d, v_d = head_dims
-@@ -2451,6 +2521,7 @@ def causal(b, h, q_idx, kv_idx):
+@@ -2495,6 +2565,7 @@ def causal(b, h, q_idx, kv_idx):
 @common_utils.parametrize("head_dim", [17, 24, 94, 121])
 @dtypes(*device_configs["cpu"].dtypes_fast)
 @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
 + @dtypesIfXPU(*device_configs["xpu"].dtypes_fast)
 def test_non_pow_2_headdim(self, device, dtype, head_dim):
 self.run_test(_rel_bias, dtype, device, B, H, S, head_dim, B, H, S, head_dim)
 
-@@ -2515,6 +2586,7 @@ def causal_constructor(S):
+@@ -2559,6 +2630,7 @@ def causal_constructor(S):
 @skip_on_cpu
 @dtypes(*device_configs["cpu"].dtypes)
 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
 + @dtypesIfXPU(*device_configs["xpu"].dtypes)
 @common_utils.parametrize("score_mod", [_identity, _causal])
 def test_logsumexp_correctness(self, device, dtype, score_mod):
 make_tensor = functools.partial(
-@@ -2971,7 +3043,7 @@ def test_flex_attention_backward_stride_ordering(
+@@ -3015,7 +3087,7 @@ def test_flex_attention_backward_stride_ordering(
 def test_non_contiguous_last_dim(self, device):
 """Test flex_attention with tensors having non contiguous last dimension."""
 B, H, D = 4, 8, 64
@@ -413,7 +413,7 @@ index e78cf68244ee..79cb9d102bdd 100644
 for S in [16, 64]:
 
 def column_major_tensor():
-@@ -3803,7 +3875,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
+@@ -3847,7 +3919,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
 
 class mask_graph0(torch.nn.Module):
 def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
@@ -422,7 +422,7 @@ index e78cf68244ee..79cb9d102bdd 100644
 return full_default
 """.replace( # noqa: B950
 "GPU_TYPE", torch.device(device).type
-@@ -4091,9 +4163,9 @@ def flex_attn_fn(x):
+@@ -4135,9 +4207,9 @@ def flex_attn_fn(x):
 return output
 
 flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
@@ -434,7 +434,7 @@ index e78cf68244ee..79cb9d102bdd 100644
 
 # Run without compilation
 output_module = flex_module(x)
-@@ -4188,12 +4260,13 @@ def make_tensor():
+@@ -4232,12 +4304,13 @@ def make_tensor():
 
 @supported_platform
 @skip_on_cpu
@@ -450,15 +450,15 @@ index e78cf68244ee..79cb9d102bdd 100644
 dtype=torch.bfloat16,
 )
 query, key, value = make_tensor(), make_tensor(), make_tensor()
-@@ -4777,6 +4850,7 @@ def flex_attention_fn():
+@@ -4821,6 +4894,7 @@ def flex_attention_fn():
 )
 
 @supported_platform
 + @skip_on_xpu
 def test_create_is_cuda_graphable(self, device):
 def mask_mod(b, h, q, kv):
 return q >= kv
-@@ -4958,7 +5032,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
+@@ -5002,7 +5076,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
 self.assertIsNone(cpu_mask.q_indices)
 
 
@@ -467,15 +467,15 @@ index e78cf68244ee..79cb9d102bdd 100644
 class TestPagedAttention(InductorTestCase):
 def setUp(self):
 super().setUp()
-@@ -5273,6 +5347,7 @@ def test_update(self, device):
+@@ -5317,6 +5391,7 @@ def test_update(self, device):
 @supported_platform
 @dtypes(*device_configs["cpu"].dtypes)
 @dtypesIfCUDA(*device_configs["cuda"].dtypes)
 + @dtypesIfXPU(*device_configs["xpu"].dtypes)
 @common_utils.parametrize("score_mod", test_score_mods)
 def test_paged_builtin_score_mods(
 self, device, dtype: torch.dtype, score_mod: Callable
-@@ -5401,14 +5476,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
+@@ -5445,14 +5520,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
 
 
 supports_learnable_bias = unittest.skipUnless(
@@ -497,7 +497,16 @@ index e78cf68244ee..79cb9d102bdd 100644
 class TestLearnableBiases(InductorTestCase):
 def setUp(self):
 super().setUp()
-@@ -6299,10 +6377,22 @@ def _test_learnable_bias_inner(
+@@ -5505,7 +5583,7 @@ def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35):
+ def _check_outputs_and_grads(
+ self, out_eager, out_compiled, out_gold, tensors, names=None
+ ):
+- backwards_grad = torch.randn_like(out_eager)
++ backwards_grad = torch.randn_like(out_eager, device="cpu").to(out_eager.device)
+ grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad)
+ grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad)
+ grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad)
+@@ -6343,10 +6421,22 @@ def _test_learnable_bias_inner(
 )
 
 
@@ -690,14 +699,14 @@ index b5ec59dc291c..777892a0ce2d 100644
 if __name__ == "__main__":
 from torch._inductor.test_case import run_tests
 diff --git a/third_party/xpu.txt b/third_party/xpu.txt
-index f3cfe7166aa7..d13f6ae35d03 100644
+index b84ebb55a901..42d53a213bd4 100644
 --- a/third_party/xpu.txt
 +++ b/third_party/xpu.txt
 @@ -1 +1 @@
--3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
-+3f07dd52aac2e466c3c3efc15f88118f21428272
+-1f7a57f50745a429b7da10dddf2e366687659b87
++2d6a5c68eca42378e0df9c92171f090eecdf5f96
 diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
-index 0553fd06755d..d094a48627fb 100644
+index b6f5646bb57c..0cc877e75ebf 100644
 --- a/torch/_inductor/kernel/flex/flex_attention.py
 +++ b/torch/_inductor/kernel/flex/flex_attention.py
 @@ -531,7 +531,9 @@ def flex_attention(
@@ -711,7 +720,7 @@ index 0553fd06755d..d094a48627fb 100644
 
 # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
 SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
-@@ -1653,7 +1655,9 @@ def flex_attention_backward(*args, **kwargs):
+@@ -1655,7 +1657,9 @@ def flex_attention_backward(*args, **kwargs):
 
 dtype = query.get_dtype()
 head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -723,7 +732,7 @@ index 0553fd06755d..d094a48627fb 100644
 # Default config for warp specialization
 num_consumer_groups, num_buffers_warp_spec = 0, 0
 diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
-index 83c6b59cec96..e89981286ed8 100644
+index 7f92fbc705a5..c5868cb21bae 100644
 --- a/torch/_inductor/kernel/flex/flex_decoding.py
 +++ b/torch/_inductor/kernel/flex/flex_decoding.py
 @@ -354,7 +354,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
@@ -738,7 +747,7 @@ index 83c6b59cec96..e89981286ed8 100644
 bh = max(B * H, 1) # NOTE: Handle B*h=0 case
 assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
 split_k = num_SM // bh * 2 # Each SM should at least get one block.
-@@ -458,7 +461,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
+@@ -459,7 +462,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
 choices: list[Any] = []
 dtype = key.get_dtype()
 head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
@@ -749,7 +758,7 @@ index 83c6b59cec96..e89981286ed8 100644
 
 # TODO: fix autotuning.
 
-@@ -505,7 +510,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
+@@ -506,7 +511,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
 )
 * gqa_shared_heads
 ),
@@ -759,7 +768,7 @@ index 83c6b59cec96..e89981286ed8 100644
 ),
 )
 diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
-index eec1d055ddf7..f7a5aefb5cd1 100644
+index 57eaef9b4dbb..f5f414a68539 100644
 --- a/torch/_inductor/template_heuristics.py
 +++ b/torch/_inductor/template_heuristics.py
 @@ -3,6 +3,7 @@
@@ -770,7 +779,7 @@ index eec1d055ddf7..f7a5aefb5cd1 100644
 from functools import partial
 from threading import Lock
 from typing import Any, Callable, Optional, TYPE_CHECKING
-@@ -1203,6 +1204,97 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
+@@ -1208,6 +1209,114 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
 Placeholder child class for XPU specific overrides.
 """
 
@@ -788,6 +797,23 @@ index eec1d055ddf7..f7a5aefb5cd1 100644
 + (torch.float16, 128): FlexConfig(128, 64, 1, 16),
 + (torch.float16, 256): FlexConfig(32, 64, 1, 4),
 + }
++ self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [
++ FlexConfig(32, 16, 2, 4),
++ FlexConfig(128, 64, 2, 16),
++ FlexConfig(128, 64, 2, 8),
++ FlexConfig(128, 32, 2, 16),
++ FlexConfig(128, 32, 2, 8),
++ ]
++ self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [
++ FlexDecodeConfig(32, 1, 2),
++ FlexDecodeConfig(32, 1, 1),
++ FlexDecodeConfig(32, 2, 2),
++ FlexDecodeConfig(32, 2, 1),
++ FlexDecodeConfig(64, 1, 2),
++ FlexDecodeConfig(64, 1, 1),
++ FlexDecodeConfig(64, 2, 2),
++ FlexDecodeConfig(64, 2, 1),
++ ]
 +
 + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
 + flex_attn_fwd_configs: list[FlexConfig] = []
@@ -899,7 +925,7 @@ index ec8027595e6f..f1d290467fb5 100644
 "FlexAttention is only supported on CUDA, CPU or HPU devices. "
 f"Found input tensors on {query.device.type} device."
 diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
-index 01499280da8f..6a5951fde65d 100644
+index 528497ba5457..061c2a2eb819 100644
 --- a/torch/testing/_internal/common_device_type.py
 +++ b/torch/testing/_internal/common_device_type.py
 @@ -1342,8 +1342,8 @@ def dep_fn(self, *args, **kwargs):
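
Note: besides rebasing hunk offsets onto the new PyTorch pin, the one behavioral tweak folded into the patched test harness is in _check_outputs_and_grads, where the backward seed is now drawn on CPU and then moved to the output's device. Below is a minimal standalone sketch of that pattern, not part of the commit; the helper name and the device probe are illustrative assumptions, and the presumed intent is that the cotangent values come from the CPU RNG so they do not depend on the accelerator backend.

import torch

def make_backward_seed(out: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the patched pattern: draw the random
    # cotangent on CPU, then copy it to the output's device.
    return torch.randn_like(out, device="cpu").to(out.device)

if __name__ == "__main__":
    # Illustrative device pick; falls back to CPU when no XPU is present.
    device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
    out = torch.randn(4, 8, device=device)
    grad_seed = make_backward_seed(out)
    print(grad_seed.device, grad_seed.shape)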
