Skip to content

Commit 909395a

Browse files
authored
Merge branch 'main' into amyachev/issue3087
2 parents 2ed752e + 36aa1cc commit 909395a

File tree

6 files changed

+29
-21
lines changed

6 files changed

+29
-21
lines changed

.github/pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
61dc5e9c0a36d590adc47b4110efd94d9eb59306
1+
1e881ceecfe80532206ca4e0acb64391fab8b935

python/test/unit/language/test_compile_errors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,9 @@ def test_min_dot_size(dtype):
406406
pytest.skip("fp16 FMA path supports all sizes")
407407
else:
408408
error_msg = "M >= 16, N >= 16 and K >= 16"
409+
elif is_xpu():
410+
# XPU supports all sizes
411+
pass
409412
else:
410413
pytest.skip("Test only supported on CUDA and HIP")
411414

python/triton/runtime/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi
103103
if os.getenv("VERBOSE"):
104104
print(" ".join(cc_cmd))
105105

106-
ret = subprocess.check_call(cc_cmd)
106+
ret = subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
107107
if ret == 0:
108108
return so
109109
# extra arguments

scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
1717
cd "$REPO_ROOT"
1818

1919
curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
20+
2021
git apply "${SCRIPT_DIR}/pytorch.patch"

scripts/pytorch.patch

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ index 4d7a85029e3..f3d45ea5520 100644
4646

4747
@requires_gpu
4848
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
49-
index ace56135fe1..7e925dd6e45 100644
49+
index c3f72bc5215..03aab72dca9 100644
5050
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
5151
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
52-
@@ -238,7 +238,7 @@ def generate_ttir(
52+
@@ -239,7 +239,7 @@ def generate_ttir(
5353

5454
target = triton.runtime.driver.active.get_current_target()
5555
backend = triton.compiler.compiler.make_backend(target)
@@ -58,17 +58,20 @@ index ace56135fe1..7e925dd6e45 100644
5858
except ImportError:
5959
return kernel._get_config(*args)
6060

61-
@@ -247,7 +247,8 @@ def generate_ttir(
61+
@@ -248,9 +248,10 @@ def generate_ttir(
6262
name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
6363
}
6464

6565
- # Build kernel signature -- doesn't include constexpr arguments.
6666
+ # Build kernel signature; it should also include `constexpr` arguments but `kernel._key_of`
6767
+ # doesn't work correctly with them. They will be added in `fixup_signature` function later.
6868
signature = {
69-
name: kernel._type_of(kernel._key_of(arg))
69+
- name: kernel._type_of(kernel._key_of(arg))
70+
+ name: triton.runtime.jit.mangle_type(arg)
7071
for i, (name, arg) in enumerate(ordered_args.items())
71-
@@ -257,7 +258,18 @@ def generate_ttir(
72+
if i not in kernel.constexprs
73+
}
74+
@@ -258,7 +259,18 @@ def generate_ttir(
7275
triton._C.libtriton.ir.load_dialects(context)
7376
backend.load_dialects(context)
7477

@@ -135,12 +138,12 @@ index 276c01f3f42..5c633b7963b 100644
135138

136139
# Instantiate AttrsDescriptor with the prepared arguments
137140
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
138-
index af8530e94d0..1ec44de9806 100644
141+
index 281d0e78ba4..901263df4aa 100644
139142
--- a/torch/_inductor/runtime/triton_heuristics.py
140143
+++ b/torch/_inductor/runtime/triton_heuristics.py
141-
@@ -435,11 +435,22 @@ class CachingAutotuner(KernelInterface):
142-
else:
143-
triton_helpers.set_driver_to_gpu()
144+
@@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
145+
if not ASTSource:
146+
raise RuntimeError("Installed triton version too old, please upgrade")
144147

145148
+ def fixup_signature(arg_names, signature, constants):
146149
+ new_signature = {arg_name: None for arg_name in arg_names}
@@ -153,12 +156,11 @@ index af8530e94d0..1ec44de9806 100644
153156
+ new_signature[arg_name] = signature[arg_name]
154157
+ return new_signature
155158
+
156-
if ASTSource:
157-
compile_args = (
158-
ASTSource(
159-
self.fn,
160-
- compile_meta["signature"],
161-
+ fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
162-
compile_meta["constants"],
163-
compile_meta["configs"][0],
164-
),
159+
compile_args = (
160+
ASTSource(
161+
self.fn,
162+
- compile_meta["signature"],
163+
+ fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
164+
compile_meta["constants"],
165+
compile_meta["configs"][0],
166+
),

third_party/intel/backend/arch_parser.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <iostream>
10+
911
#include <sycl/sycl.hpp>
1012

1113
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
@@ -31,7 +33,7 @@ static PyObject *parseDeviceArch(PyObject *self, PyObject *args) {
3133
arch = "lnl";
3234
break;
3335
default:
34-
printf("sycl_arch = %d", sycl_arch);
36+
std::cerr << "sycl_arch not recognized: " << (int)sycl_arch << std::endl;
3537
}
3638

3739
return Py_BuildValue("s", arch.c_str());

0 commit comments

Comments (0)