
Commit 4d11eb7

Update PyTorch pin (#4826)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent: 9d65e08 · commit: 4d11eb7

File tree

5 files changed: +56, -53 lines changed


.github/pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-6d071bd65de9bdc354f32adf67e00d6e13475e76
+83e2ea8135c42fa826c3d751a04f60259e97147f

scripts/patch-pytorch.sh

Lines changed: 0 additions & 1 deletion
@@ -40,4 +40,3 @@ apply_patch ./patch/pytorch_fp64.patch
 apply_patch ./patch/pytorch_global_scratch.patch
 apply_patch ./patch/test_compile_subprocess.patch
 apply_patch ./patch/flex_decoding.patch
-apply_patch ./patch/flex_attn_bwd_num_stage_1.patch

scripts/patch/flex_attn_143553.patch

Lines changed: 50 additions & 25 deletions
@@ -525,7 +525,7 @@ index e78cf68244ee..79cb9d102bdd 100644
 if __name__ == "__main__":
 from torch._inductor.test_case import run_tests
 diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
-index 5af7b284f757..667915c69507 100644
+index b5ec59dc291c..777892a0ce2d 100644
 --- a/test/inductor/test_flex_decoding.py
 +++ b/test/inductor/test_flex_decoding.py
 @@ -27,6 +27,7 @@
@@ -598,7 +598,7 @@ index 5af7b284f757..667915c69507 100644
 requires_grad=True,
 )
 q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q()
-@@ -998,12 +1007,12 @@ def mask_mod(b, h, q, kv):
+@@ -999,12 +1008,12 @@ def mask_mod(b, h, q, kv):

 @supported_platform
 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@@ -615,7 +615,7 @@ index 5af7b284f757..667915c69507 100644

 def score_mod(score, b, h, q, kv):
 return score + offset_kv[kv] + offset_q[q]
-@@ -1011,8 +1020,14 @@ def score_mod(score, b, h, q, kv):
+@@ -1012,8 +1021,14 @@ def score_mod(score, b, h, q, kv):
 def mask_mod(b, h, q, kv):
 return kv >= q + offset_tensor

@@ -632,7 +632,7 @@ index 5af7b284f757..667915c69507 100644

 @supported_platform
 @common_utils.parametrize("dtype", test_dtypes_fast)
-@@ -1677,19 +1692,19 @@ def mask_mod(b, h, q, kv):
+@@ -1679,19 +1694,19 @@ def mask_mod(b, h, q, kv):
 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
 @common_utils.parametrize("dtype", test_dtypes)
 @common_utils.parametrize("score_mod", [_identity, _causal])
@@ -655,7 +655,7 @@ index 5af7b284f757..667915c69507 100644
 requires_grad=True,
 )
 q, k, v = make_q(), make_kv(), make_kv()
-@@ -1729,19 +1744,19 @@ def eager_sdpa_hop(q, k, v, score_mod):
+@@ -1731,19 +1746,19 @@ def eager_sdpa_hop(q, k, v, score_mod):

 @supported_platform
 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
@@ -678,7 +678,7 @@ index 5af7b284f757..667915c69507 100644
 requires_grad=True,
 )

-@@ -1993,7 +2008,9 @@ def causal_mask(b, h, q, kv):
+@@ -1995,7 +2010,9 @@ def causal_mask(b, h, q, kv):
 self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out")


@@ -689,11 +689,18 @@ index 5af7b284f757..667915c69507 100644

 if __name__ == "__main__":
 from torch._inductor.test_case import run_tests
-diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
-index e471332afe71..ced92fae6229 100644
---- a/torch/_inductor/kernel/flex_attention.py
-+++ b/torch/_inductor/kernel/flex_attention.py
-@@ -1445,7 +1445,9 @@ def flex_attention(
+diff --git a/third_party/xpu.txt b/third_party/xpu.txt
+index f3cfe7166aa7..d13f6ae35d03 100644
+--- a/third_party/xpu.txt
++++ b/third_party/xpu.txt
+@@ -1 +1 @@
+-3a9419c8bb6a98dd3e3cd473c36691fb4abeae40
++3f07dd52aac2e466c3c3efc15f88118f21428272
+diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
+index 0553fd06755d..d094a48627fb 100644
+--- a/torch/_inductor/kernel/flex/flex_attention.py
++++ b/torch/_inductor/kernel/flex/flex_attention.py
+@@ -531,7 +531,9 @@ def flex_attention(

 dtype = query.get_dtype()
 head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -704,7 +711,7 @@ index e471332afe71..ced92fae6229 100644

 # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
 SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
-@@ -2567,7 +2569,9 @@ def flex_attention_backward(*args, **kwargs):
+@@ -1653,7 +1655,9 @@ def flex_attention_backward(*args, **kwargs):

 dtype = query.get_dtype()
 head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -715,11 +722,11 @@ index e471332afe71..ced92fae6229 100644

 # Default config for warp specialization
 num_consumer_groups, num_buffers_warp_spec = 0, 0
-diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
-index 7e0aef981856..628bfc6419be 100644
---- a/torch/_inductor/kernel/flex_decoding.py
-+++ b/torch/_inductor/kernel/flex_decoding.py
-@@ -310,7 +310,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
+diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
+index 83c6b59cec96..e89981286ed8 100644
+--- a/torch/_inductor/kernel/flex/flex_decoding.py
++++ b/torch/_inductor/kernel/flex/flex_decoding.py
+@@ -354,7 +354,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me


 def get_split_k(B: int, H: int, Mk: int) -> int:
@@ -731,7 +738,7 @@ index 7e0aef981856..628bfc6419be 100644
 bh = max(B * H, 1) # NOTE: Handle B*h=0 case
 assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
 split_k = num_SM // bh * 2 # Each SM should at least get one block.
-@@ -415,7 +418,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
+@@ -458,7 +461,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
 choices: list[Any] = []
 dtype = key.get_dtype()
 head_dim = V.graph.sizevars.guard_int(key.get_size()[-1])
@@ -742,24 +749,31 @@ index 7e0aef981856..628bfc6419be 100644

 # TODO: fix autotuning.

-@@ -462,7 +467,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
+@@ -505,7 +510,7 @@ def create_flex_decoding_kernel(*args, **kwargs):
 )
 * gqa_shared_heads
 ),
 - 16,
-+ float('-inf') if torch.xpu.is_available() else 16,
++ 1 if torch.xpu.is_available() else 16,
 )
 ),
 )
 diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
-index 65a6851192a0..3a53f0eed52e 100644
+index eec1d055ddf7..f7a5aefb5cd1 100644
 --- a/torch/_inductor/template_heuristics.py
 +++ b/torch/_inductor/template_heuristics.py
-@@ -1201,3 +1201,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
-"""
+@@ -3,6 +3,7 @@
+import dataclasses
+import itertools
+import math
++import os
+from functools import partial
+from threading import Lock
+from typing import Any, Callable, Optional, TYPE_CHECKING
+@@ -1203,6 +1204,97 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
 Placeholder child class for XPU specific overrides.
 """
-+
+
 + def __init__(self) -> None:
 + super().__init__()
 +
@@ -804,6 +818,9 @@ index 65a6851192a0..3a53f0eed52e 100644
 +
 + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]:
 + flex_attn_bwd_configs: list[FlexConfig] = []
++ TRITON_LESS_FLEX_ATTN_BWD_CONFIGS = os.getenv(
++ "TRITON_LESS_FLEX_ATTN_BWD_CONFIGS", "0"
++ ).lower() in {"true", "1", "t", "y", "yes", "on"}
 +
 + if config.max_autotune:
 + if config.max_autotune_flex_search_space == "EXHAUSTIVE":
@@ -825,6 +842,10 @@ index 65a6851192a0..3a53f0eed52e 100644
 + if default_config not in flex_attn_bwd_configs:
 + flex_attn_bwd_configs.append(default_config)
 +
++ if TRITON_LESS_FLEX_ATTN_BWD_CONFIGS:
++ flex_attn_bwd_configs = list(
++ filter(lambda c: c.num_stages == 1, flex_attn_bwd_configs)
++ )
 + return flex_attn_bwd_configs
 +
 + def get_flex_decode_configs(
@@ -843,8 +864,12 @@ index 65a6851192a0..3a53f0eed52e 100644
 + flex_decode_configs.append(default_config)
 +
 + return flex_decode_configs
++
+
+class MTIAConfigHeuristic(BaseConfigHeuristic):
+"""
 diff --git a/torch/_ops.py b/torch/_ops.py
-index fecfebaeaa53..8fac24a8579c 100644
+index 83a5dc0e57a5..b351aa17dfa7 100644
 --- a/torch/_ops.py
 +++ b/torch/_ops.py
 @@ -267,6 +267,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
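
Note on the functional change in this refreshed patch: the updated flex_attn_143553.patch adds a TRITON_LESS_FLEX_ATTN_BWD_CONFIGS environment-variable gate to XPUConfigHeuristic.get_flex_attn_bwd_configs, which trims the flex-attention backward autotuning space to single-stage configs and makes the separate flex_attn_bwd_num_stage_1.patch (deleted below) unnecessary. The following is a minimal standalone sketch of that gate, not the Inductor code itself; FlexConfig and filter_bwd_configs here are simplified stand-ins.

import os
from dataclasses import dataclass


@dataclass
class FlexConfig:
    # Simplified stand-in for Inductor's FlexConfig; only num_stages matters here.
    block_m: int
    block_n: int
    num_stages: int
    num_warps: int


def filter_bwd_configs(configs: list[FlexConfig]) -> list[FlexConfig]:
    # Truthy values of TRITON_LESS_FLEX_ATTN_BWD_CONFIGS (the same set the patch
    # checks) restrict the backward autotuning space to num_stages == 1 configs.
    less_configs = os.getenv(
        "TRITON_LESS_FLEX_ATTN_BWD_CONFIGS", "0"
    ).lower() in {"true", "1", "t", "y", "yes", "on"}
    if less_configs:
        configs = [c for c in configs if c.num_stages == 1]
    return configs


# Example: with the variable set, only the single-stage config survives.
print(filter_bwd_configs([FlexConfig(128, 64, 1, 4), FlexConfig(128, 64, 3, 8)]))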

scripts/patch/flex_attn_bwd_num_stage_1.patch

Lines changed: 0 additions & 21 deletions
This file was deleted.

scripts/patch/flex_decoding.patch

Lines changed: 5 additions & 5 deletions
@@ -1,7 +1,7 @@
-diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
-index 628bfc6419b..e86aca59db6 100644
---- a/torch/_inductor/kernel/flex_decoding.py
-+++ b/torch/_inductor/kernel/flex_decoding.py
+diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py
+index 83c6b59cec96..e89981286ed8 100644
+--- a/torch/_inductor/kernel/flex/flex_decoding.py
++++ b/torch/_inductor/kernel/flex/flex_decoding.py
 @@ -459,15 +459,12 @@ def create_flex_decoding_kernel(*args, **kwargs):
 # m
 # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
@@ -14,7 +14,7 @@ index 628bfc6419b..e86aca59db6 100644
 - )
 - * gqa_shared_heads
 - ),
-- float('-inf') if torch.xpu.is_available() else 16,
+- 1 if torch.xpu.is_available() else 16,
 + next_power_of_2(
 + V.graph.sizevars.size_hint(
 + seq_len_q,
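
For context: the only content change in this refreshed flex_decoding.patch is the line it expects to remove from create_flex_decoding_kernel, which now reads 1 if torch.xpu.is_available() else 16 (the floor introduced by the updated flex_attn_143553.patch above) instead of the earlier float('-inf') variant, keeping the local patch applicable against the new pin. That value is the lower bound of a max(...) clamp on the decoding block size. A hedged sketch of the clamp follows; decode_block_floor is an illustrative helper, not the actual Inductor expression.

import torch


def decode_block_floor(candidate: int) -> int:
    # Hedged sketch (not the Inductor code): the patched value supplies the
    # lower bound of a max(...) clamp on the flex-decoding block size.
    # Upstream clamps to at least 16; the XPU patch lowers the floor to 1,
    # replacing the earlier float('-inf') floor that disabled the clamp.
    floor = 1 if torch.xpu.is_available() else 16
    return max(candidate, floor)


# Example: a candidate block size of 8 stays 8 on XPU but is raised to 16 elsewhere.
print(decode_block_floor(8))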
