Skip to content

Commit 4e46df3

Browse files
authored
Update pytorch pin and fix issues related to new Triton interface (#3097)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 8f1f00e commit 4e46df3

File tree

3 files changed

+22
-19
lines changed

3 files changed

+22
-19
lines changed

.github/pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
61dc5e9c0a36d590adc47b4110efd94d9eb59306
1+
1e881ceecfe80532206ca4e0acb64391fab8b935

scripts/patch-pytorch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ echo "Applying PyTorch patches in $REPO_ROOT"
1717
cd "$REPO_ROOT"
1818

1919
curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
20+
2021
git apply "${SCRIPT_DIR}/pytorch.patch"

scripts/pytorch.patch

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ index 4d7a85029e3..f3d45ea5520 100644
4646

4747
@requires_gpu
4848
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
49-
index ace56135fe1..7e925dd6e45 100644
49+
index c3f72bc5215..03aab72dca9 100644
5050
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
5151
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
52-
@@ -238,7 +238,7 @@ def generate_ttir(
52+
@@ -239,7 +239,7 @@ def generate_ttir(
5353

5454
target = triton.runtime.driver.active.get_current_target()
5555
backend = triton.compiler.compiler.make_backend(target)
@@ -58,17 +58,20 @@ index ace56135fe1..7e925dd6e45 100644
5858
except ImportError:
5959
return kernel._get_config(*args)
6060

61-
@@ -247,7 +247,8 @@ def generate_ttir(
61+
@@ -248,9 +248,10 @@ def generate_ttir(
6262
name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
6363
}
6464

6565
- # Build kernel signature -- doesn't include constexpr arguments.
6666
+ # Build kernel signature; it should also include `constexpr` arguments but `kernel._key_of`
6767
+ # doesn't work correctly with them. They will be added in `fixup_signature` function later.
6868
signature = {
69-
name: kernel._type_of(kernel._key_of(arg))
69+
- name: kernel._type_of(kernel._key_of(arg))
70+
+ name: triton.runtime.jit.mangle_type(arg)
7071
for i, (name, arg) in enumerate(ordered_args.items())
71-
@@ -257,7 +258,18 @@ def generate_ttir(
72+
if i not in kernel.constexprs
73+
}
74+
@@ -258,7 +259,18 @@ def generate_ttir(
7275
triton._C.libtriton.ir.load_dialects(context)
7376
backend.load_dialects(context)
7477

@@ -135,12 +138,12 @@ index 276c01f3f42..5c633b7963b 100644
135138

136139
# Instantiate AttrsDescriptor with the prepared arguments
137140
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
138-
index af8530e94d0..1ec44de9806 100644
141+
index 281d0e78ba4..901263df4aa 100644
139142
--- a/torch/_inductor/runtime/triton_heuristics.py
140143
+++ b/torch/_inductor/runtime/triton_heuristics.py
141-
@@ -435,11 +435,22 @@ class CachingAutotuner(KernelInterface):
142-
else:
143-
triton_helpers.set_driver_to_gpu()
144+
@@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
145+
if not ASTSource:
146+
raise RuntimeError("Installed triton version too old, please upgrade")
144147

145148
+ def fixup_signature(arg_names, signature, constants):
146149
+ new_signature = {arg_name: None for arg_name in arg_names}
@@ -153,12 +156,11 @@ index af8530e94d0..1ec44de9806 100644
153156
+ new_signature[arg_name] = signature[arg_name]
154157
+ return new_signature
155158
+
156-
if ASTSource:
157-
compile_args = (
158-
ASTSource(
159-
self.fn,
160-
- compile_meta["signature"],
161-
+ fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
162-
compile_meta["constants"],
163-
compile_meta["configs"][0],
164-
),
159+
compile_args = (
160+
ASTSource(
161+
self.fn,
162+
- compile_meta["signature"],
163+
+ fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
164+
compile_meta["constants"],
165+
compile_meta["configs"][0],
166+
),

0 commit comments

Comments
 (0)