@@ -46,10 +46,10 @@ index 4d7a85029e3..f3d45ea5520 100644
4646
4747 @requires_gpu
4848diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
49- index ace56135fe1..7e925dd6e45 100644
49+ index c3f72bc5215..03aab72dca9 100644
5050--- a/torch/_higher_order_ops/triton_kernel_wrap.py
5151+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
52- @@ -238 ,7 +238 ,7 @@ def generate_ttir(
52+ @@ -239 ,7 +239 ,7 @@ def generate_ttir(
5353
5454 target = triton.runtime.driver.active.get_current_target()
5555 backend = triton.compiler.compiler.make_backend(target)
@@ -58,17 +58,20 @@ index ace56135fe1..7e925dd6e45 100644
5858 except ImportError:
5959 return kernel._get_config(*args)
6060
61- @@ -247,7 +247,8 @@ def generate_ttir(
61+ @@ -248,9 +248,10 @@ def generate_ttir(
6262 name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
6363 }
6464
6565- # Build kernel signature -- doesn't include constexpr arguments.
6666+ # Build kernel signature; it should also include `constexpr` arguments but `kernel._key_of`
6767+ # doesn't work correctly with them. They will be added in `fixup_signature` function later.
6868 signature = {
69- name: kernel._type_of(kernel._key_of(arg))
69+ - name: kernel._type_of(kernel._key_of(arg))
70+ + name: triton.runtime.jit.mangle_type(arg)
7071 for i, (name, arg) in enumerate(ordered_args.items())
71- @@ -257,7 +258,18 @@ def generate_ttir(
72+ if i not in kernel.constexprs
73+ }
74+ @@ -258,7 +259,18 @@ def generate_ttir(
7275 triton._C.libtriton.ir.load_dialects(context)
7376 backend.load_dialects(context)
7477
@@ -135,12 +138,12 @@ index 276c01f3f42..5c633b7963b 100644
135138
136139 # Instantiate AttrsDescriptor with the prepared arguments
137140diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
138- index af8530e94d0..1ec44de9806 100644
141+ index 281d0e78ba4..901263df4aa 100644
139142--- a/torch/_inductor/runtime/triton_heuristics.py
140143+++ b/torch/_inductor/runtime/triton_heuristics.py
141- @@ -435,11 +435,22 @@ class CachingAutotuner(KernelInterface):
142- else :
143- triton_helpers.set_driver_to_gpu( )
144+ @@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
145+ if not ASTSource :
146+ raise RuntimeError("Installed triton version too old, please upgrade" )
144147
145148+ def fixup_signature(arg_names, signature, constants):
146149+ new_signature = {arg_name: None for arg_name in arg_names}
@@ -153,12 +156,11 @@ index af8530e94d0..1ec44de9806 100644
153156+ new_signature[arg_name] = signature[arg_name]
154157+ return new_signature
155158+
156- if ASTSource:
157- compile_args = (
158- ASTSource(
159- self.fn,
160- - compile_meta["signature"],
161- + fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
162- compile_meta["constants"],
163- compile_meta["configs"][0],
164- ),
159+ compile_args = (
160+ ASTSource(
161+ self.fn,
162+ - compile_meta["signature"],
163+ + fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
164+ compile_meta["constants"],
165+ compile_meta["configs"][0],
166+ ),
0 commit comments