 diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh
-index 51e9df623d5d1..647e77f6d17bd 100644
+index 51e9df623d5d1a..647e77f6d17bdc 100644
 --- a/.ci/docker/common/install_xpu.sh
 +++ b/.ci/docker/common/install_xpu.sh
 @@ -35,12 +35,12 @@ function install_ubuntu() {
@@ -21,7 +21,7 @@ index 51e9df623d5d1..647e77f6d17bd 100644
         apt-get install -y intel-ocloc
     fi
 diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile
-index a0e7dce3df4d5..9cd30e0178bf9 100644
+index a0e7dce3df4d55..9cd30e0178bf92 100644
 --- a/.ci/docker/ubuntu-xpu/Dockerfile
 +++ b/.ci/docker/ubuntu-xpu/Dockerfile
 @@ -63,6 +63,7 @@ RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.t
@@ -33,7 +33,7 @@ index a0e7dce3df4d5..9cd30e0178bf9 100644
 RUN bash ./install_xpu.sh && rm install_xpu.sh

 diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
-index e6b46ae8edc9d..f8c8d9f31cde3 100644
+index 4d14555800c8c4..ab40d19d2ff5ee 100644
 --- a/test/inductor/test_flex_attention.py
 +++ b/test/inductor/test_flex_attention.py
 @@ -41,20 +41,26 @@
@@ -525,7 +525,7 @@ index e6b46ae8edc9d..f8c8d9f31cde3 100644
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
-index 2125e78eea8db..f59433c7e77f2 100644
+index 3b4905fc356168..3a165e5fff2eda 100644
 --- a/test/inductor/test_flex_decoding.py
 +++ b/test/inductor/test_flex_decoding.py
 @@ -27,6 +27,7 @@
@@ -691,7 +691,7 @@ index 2125e78eea8db..f59433c7e77f2 100644
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
-index 99e869dc8fdb7..1426f000d191d 100644
+index 99e869dc8fdb71..1426f000d191d2 100644
 --- a/torch/_inductor/kernel/flex_attention.py
 +++ b/torch/_inductor/kernel/flex_attention.py
 @@ -1441,7 +1441,9 @@ def flex_attention(
@@ -717,7 +717,7 @@ index 99e869dc8fdb7..1426f000d191d 100644
     # Default config for warp specialization
     num_consumer_groups, num_buffers_warp_spec = 0, 0
 diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
-index 7e0aef9818560..343086e5b2d16 100644
+index 7e0aef98185603..343086e5b2d16a 100644
 --- a/torch/_inductor/kernel/flex_decoding.py
 +++ b/torch/_inductor/kernel/flex_decoding.py
 @@ -310,7 +310,10 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
@@ -744,7 +744,7 @@ index 7e0aef9818560..343086e5b2d16 100644
     # TODO: fix autotuning.

 diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py
-index dfd37523a3702..b830ec6369a9d 100644
+index dfd37523a37027..b830ec6369a9d7 100644
 --- a/torch/_inductor/template_heuristics.py
 +++ b/torch/_inductor/template_heuristics.py
 @@ -1178,3 +1178,87 @@ class XPUConfigHeuristic(BaseConfigHeuristic):
@@ -836,7 +836,7 @@ index dfd37523a3702..b830ec6369a9d 100644
 +
 +        return flex_decode_configs
 diff --git a/torch/_ops.py b/torch/_ops.py
-index 337b9a11e6a18..5e3423285e02b 100644
+index 337b9a11e6a180..5e3423285e02b5 100644
 --- a/torch/_ops.py
 +++ b/torch/_ops.py
 @@ -267,6 +267,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
@@ -848,7 +848,7 @@ index 337b9a11e6a18..5e3423285e02b 100644


 diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
-index 15a00e1a9d342..3c9b3a2017399 100644
+index 15a00e1a9d342b..3c9b3a20173997 100644
 --- a/torch/nn/attention/flex_attention.py
 +++ b/torch/nn/attention/flex_attention.py
 @@ -1142,11 +1142,8 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
@@ -866,7 +866,7 @@ index 15a00e1a9d342..3c9b3a2017399 100644
     "FlexAttention is only supported on CUDA, CPU or HPU devices. "
     f"Found input tensors on {query.device.type} device."
 diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
-index 01499280da8f5..6a5951fde65dc 100644
+index 01499280da8f5d..6a5951fde65dca 100644
 --- a/torch/testing/_internal/common_device_type.py
 +++ b/torch/testing/_internal/common_device_type.py
 @@ -1342,8 +1342,8 @@ def dep_fn(self, *args, **kwargs):
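Read together, the hunks above refresh a carried patch that enables FlexAttention on Intel XPU: the index hashes are updated after a rebase, and the surrounding context shows the pieces the patch touches (the XPU CI image, the inductor Triton templates and config heuristics, and the device check in torch/nn/attention/flex_attention.py). As a rough illustration only, not part of this commit, a smoke test of such a patched build might look like the following; the tensor shapes and the causal score_mod are assumptions, not taken from the diff.

import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # Standard causal mask expressed as a score_mod.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

if torch.xpu.is_available():  # assumes a PyTorch build with the XPU patch applied
    q, k, v = (
        torch.randn(2, 8, 1024, 64, device="xpu", dtype=torch.float16)
        for _ in range(3)
    )
    # torch.compile routes through the inductor templates patched above
    # (flex_attention.py, flex_decoding.py, template_heuristics.py).
    compiled_flex = torch.compile(flex_attention)
    out = compiled_flex(q, k, v, score_mod=causal)
    print(out.shape)  # torch.Size([2, 8, 1024, 64])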