
Commit b95b34f

Update PyTorch pin (#4729)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3ddfd2a commit b95b34f

3 files changed, +77 -18 lines changed

.github/pins/pytorch.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-815545f2dd6ade563cb1263f8bb7813f355edb2e
+1f57e0e04da9d334e238cec346f7ae3667bed9d1
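The pin file carries nothing but the PyTorch commit SHA that CI builds against; this commit simply moves it forward. How the pin is consumed is outside this diff, so the snippet below is only a loose sketch of a consumer, assuming an existing PyTorch clone (the paths and function name are illustrative, not this repo's actual tooling):

import subprocess
from pathlib import Path

PIN_FILE = Path(".github/pins/pytorch.txt")  # file updated by this commit

def checkout_pinned_pytorch(repo_dir: str = "pytorch") -> str:
    """Hypothetical helper: check out the pinned PyTorch commit in an existing clone."""
    sha = PIN_FILE.read_text().strip()
    subprocess.run(["git", "-C", repo_dir, "fetch", "origin", sha], check=True)
    subprocess.run(["git", "-C", repo_dir, "checkout", sha], check=True)
    return sha

if __name__ == "__main__":
    print("PyTorch checked out at", checkout_pinned_pytorch())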

scripts/patch-pytorch.sh

Lines changed: 0 additions & 1 deletion

@@ -36,6 +36,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
 
 # put your patch applies here
 apply_patch ./patch/flex_attn_143553.patch
-# trigger build
 apply_patch pytorch_fp64.patch
 apply_patch ./patch/pytorch_global_scratch.patch
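The only change here is dropping a stale "# trigger build" comment; the list of patches applied on top of the pinned PyTorch checkout is unchanged. apply_patch itself is a shell helper defined earlier in the script and not shown in this hunk, so the following is just a rough Python sketch of the usual "apply only if not already applied" behaviour such helpers tend to implement (an assumption, not the script's actual logic):

import subprocess

def apply_patch(repo_dir: str, patch_file: str) -> None:
    """Rough sketch: apply patch_file to repo_dir unless it is already applied."""
    # If the patch reverse-applies cleanly, its changes are already in the tree.
    check = subprocess.run(
        ["git", "-C", repo_dir, "apply", "--reverse", "--check", patch_file],
        capture_output=True,
    )
    if check.returncode == 0:
        print(f"Skipping {patch_file}: already applied")
        return
    subprocess.run(["git", "-C", repo_dir, "apply", patch_file], check=True)
    print(f"Applied {patch_file}")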

scripts/patch/flex_attn_143553.patch

Lines changed: 76 additions & 16 deletions
@@ -1,5 +1,5 @@
 diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh
-index 51e9df623d5d1a..647e77f6d17bdc 100644
+index ecbbb8ccccf897..6349a7c6829c77 100644
 --- a/.ci/docker/common/install_xpu.sh
 +++ b/.ci/docker/common/install_xpu.sh
 @@ -35,12 +35,12 @@ function install_ubuntu() {
@@ -33,7 +33,7 @@ index a0e7dce3df4d55..9cd30e0178bf92 100644
 RUN bash ./install_xpu.sh && rm install_xpu.sh
 
 diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
-index 4d14555800c8c4..ab40d19d2ff5ee 100644
+index fa6400dd9c2724..b72d8e9021fc58 100644
 --- a/test/inductor/test_flex_attention.py
 +++ b/test/inductor/test_flex_attention.py
 @@ -41,20 +41,26 @@
@@ -450,32 +450,32 @@ index 4d14555800c8c4..ab40d19d2ff5ee 100644
              dtype=torch.bfloat16,
          )
          query, key, value = make_tensor(), make_tensor(), make_tensor()
-@@ -4730,6 +4803,7 @@ def flex_attention_fn():
+@@ -4722,6 +4795,7 @@ def flex_attention_fn():
          )
 
      @supported_platform
 +    @skip_on_xpu
      def test_create_is_cuda_graphable(self, device):
          def mask_mod(b, h, q, kv):
              return q >= kv
-@@ -4771,7 +4845,7 @@ def create_inputs(S):
-         flex_attention_call(*create_inputs(1024), block_mask=block_mask)
+@@ -4903,7 +4977,7 @@ def test_block_mask_operations_with_none_q_indices(self, device):
+         self.assertIsNone(cpu_mask.q_indices)
 
 
 -@large_tensor_test_class("2GB", device="cuda")
 +@large_tensor_test_class("2GB", device=test_device[0])
  class TestPagedAttention(InductorTestCase):
      def setUp(self):
          super().setUp()
-@@ -5086,6 +5160,7 @@ def test_update(self, device):
+@@ -5218,6 +5292,7 @@ def test_update(self, device):
      @supported_platform
      @dtypes(*device_configs["cpu"].dtypes)
      @dtypesIfCUDA(*device_configs["cuda"].dtypes)
 +    @dtypesIfXPU(*device_configs["xpu"].dtypes)
      @common_utils.parametrize("score_mod", test_score_mods)
      def test_paged_builtin_score_mods(
          self, device, dtype: torch.dtype, score_mod: Callable
-@@ -5214,14 +5289,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
+@@ -5346,14 +5421,17 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
 
 
  supports_learnable_bias = unittest.skipUnless(
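The test-side edits above all follow one pattern: device-specific decorators (@skip_on_xpu, @dtypesIfXPU(...)) and device-derived arguments (device=test_device[0]) so the same suite runs on CPU, CUDA and XPU. The sketch below illustrates that pattern with plain unittest; the dtype table, device probe and skip_on_xpu helper here are hypothetical stand-ins, not PyTorch's actual test decorators:

import unittest
import torch

# Hypothetical per-device dtype table, echoing the device_configs[...] lookups above.
DEVICE_DTYPES = {
    "cpu": [torch.float32],
    "cuda": [torch.float16, torch.bfloat16, torch.float32],
    "xpu": [torch.float16, torch.bfloat16, torch.float32],
}

TEST_DEVICE = (
    "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

skip_on_xpu = unittest.skipIf(TEST_DEVICE == "xpu", "not yet supported on XPU")

class DeviceGenericExample(unittest.TestCase):
    def test_add_matches_reference(self):
        # Run the same check once per dtype supported on the active device.
        for dtype in DEVICE_DTYPES[TEST_DEVICE]:
            with self.subTest(dtype=dtype):
                x = torch.ones(4, device=TEST_DEVICE, dtype=dtype)
                self.assertTrue(torch.equal((x + x).cpu(), 2 * torch.ones(4, dtype=dtype)))

    @skip_on_xpu
    def test_feature_not_ported_yet(self):
        self.assertIn(TEST_DEVICE, ("cpu", "cuda"))

if __name__ == "__main__":
    unittest.main()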
@@ -497,7 +497,7 @@ index 4d14555800c8c4..ab40d19d2ff5ee 100644
 class TestLearnableBiases(InductorTestCase):
     def setUp(self):
         super().setUp()
-@@ -6112,10 +6190,22 @@ def _test_learnable_bias_inner(
+@@ -6244,10 +6322,22 @@ def _test_learnable_bias_inner(
         )
 
 
@@ -691,7 +691,7 @@ index 3b4905fc356168..3a165e5fff2eda 100644
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
-index 99e869dc8fdb71..1426f000d191d2 100644
+index 9a7507631cc490..1a761bf3833e56 100644
 --- a/torch/_inductor/kernel/flex_attention.py
 +++ b/torch/_inductor/kernel/flex_attention.py
 @@ -1441,7 +1441,9 @@ def flex_attention(
@@ -705,7 +705,7 @@ index 99e869dc8fdb71..1426f000d191d2 100644
 
     # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
-@@ -2557,7 +2559,9 @@ def flex_attention_backward(*args, **kwargs):
+@@ -2560,7 +2562,9 @@ def flex_attention_backward(*args, **kwargs):
 
     dtype = query.get_dtype()
     head_dim = V.graph.sizevars.guard_int(query.get_size()[-1])
@@ -717,7 +717,7 @@ index 99e869dc8fdb71..1426f000d191d2 100644
     # Default config for warp specialization
     num_consumer_groups, num_buffers_warp_spec = 0, 0
 diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
-index 7e0aef98185603..343086e5b2d16a 100644
+index 7e0aef98185603..6c8bd4b593ae38 100644
 --- a/torch/_inductor/kernel/flex_decoding.py
 +++ b/torch/_inductor/kernel/flex_decoding.py
 @@ -310,7 +310,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
@@ -743,11 +743,71 @@ index 7e0aef98185603..343086e5b2d16a 100644
 
 # TODO: fix autotuning.
 
+@@ -448,24 +453,41 @@ def create_flex_decoding_kernel(*args, **kwargs):
+ 
+     set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)
+ 
+-    kernel_options.setdefault(
+-        "BLOCK_M",
+-        (
+-            # m
+-            # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
+-            # else  # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
+-            max(
+-                next_power_of_2(
+-                    V.graph.sizevars.size_hint(
+-                        seq_len_q,
+-                        fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
+-                    )
+-                    * gqa_shared_heads
+-                ),
+-                16,
+-            )
+-        ),
+-    )
++    if torch.xpu.is_available():
++        kernel_options.setdefault(
++            "BLOCK_M",
++            (
++                max(
++                    next_power_of_2(
++                        V.graph.sizevars.size_hint(
++                            seq_len_q,
++                            fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
++                        )
++                        * gqa_shared_heads
++                    ),
++                    8,
++                )
++            ),
++        )
++    else:
++        kernel_options.setdefault(
++            "BLOCK_M",
++            (
++                # m
++                # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
++                # else  # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
++                max(
++                    next_power_of_2(
++                        V.graph.sizevars.size_hint(
++                            seq_len_q,
++                            fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
++                        )
++                        * gqa_shared_heads
++                    ),
++                    16,
++                )
++            ),
++        )
+ 
+     query = ir.ExternKernel.realize_input(query)
+     stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()
 diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
-index dfd37523a37027..b830ec6369a9d7 100644
+index 40a9645186792f..eaa6fbeaf0d4ea 100644
 --- a/torch/_inductor/template_heuristics.py
 +++ b/torch/_inductor/template_heuristics.py
-@@ -1178,3 +1178,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
+@@ -1185,3 +1185,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
     """
     Placeholder child class for XPU specific overrides.
     """
@@ -836,7 +896,7 @@ index dfd37523a37027..b830ec6369a9d7 100644
 +
 +        return flex_decode_configs
 diff --git a/torch/_ops.py b/torch/_ops.py
-index 337b9a11e6a180..5e3423285e02b5 100644
+index 600f6d9e1ada1c..1121ced7eaa5ff 100644
 --- a/torch/_ops.py
 +++ b/torch/_ops.py
 @@ -267,6 +267,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
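Most of the 87 lines added to template_heuristics.py are elided in this view; all that remains visible is the tail of a new XPUConfigHeuristic method ending in `return flex_decode_configs`. Purely as a guess at the shape of such a heuristic, with made-up class layout, fields and values (not the patch's contents):

from dataclasses import dataclass

@dataclass(frozen=True)
class DecodeConfig:
    # Hypothetical fields; the real config objects live in torch/_inductor.
    block_n: int
    num_warps: int
    num_stages: int

class XPUConfigHeuristicSketch:
    """Illustrative XPU-specific overrides returning flex-decoding autotune candidates."""

    def get_flex_decode_configs(self) -> list[DecodeConfig]:
        # Assumed candidate tile shapes for flex-decoding autotuning on XPU.
        flex_decode_configs = [
            DecodeConfig(block_n=32, num_warps=4, num_stages=2),
            DecodeConfig(block_n=64, num_warps=8, num_stages=2),
            DecodeConfig(block_n=128, num_warps=8, num_stages=3),
        ]
        return flex_decode_configs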
@@ -848,10 +908,10 @@ index 337b9a11e6a180..5e3423285e02b5 100644
 
 
 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
-index 15a00e1a9d342b..3c9b3a20173997 100644
+index ce592c1ed342f8..bcc180184d9aa4 100644
 --- a/torch/nn/attention/flex_attention.py
 +++ b/torch/nn/attention/flex_attention.py
-@@ -1142,11 +1142,8 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
+@@ -1146,11 +1146,8 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
     """TODO: Remove once non cuda/cpu devices support is added
     We only need to check query since we have already that q,k,v are on the same device
     """
