 diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py
- index 84264bf1b01..aa9a624d5ac 100644
+ index 84264bf1b0..54b6d028cf 100644
 --- a/test/inductor/test_codegen_triton.py
 +++ b/test/inductor/test_codegen_triton.py
- @@ -48,7 +48,7 @@ class TestCodegenTriton(InductorTestCase):
- return config.divisible_by_16
+ @@ -39,45 +39,31 @@ class TestCodegenTriton(InductorTestCase):
+ s0 = sympy.Symbol("s0", positive=True, integer=True)
+ s1 = sympy.Symbol("s1", positive=True, integer=True)

- self.assertEqual(
+ - def _check_divisibility(config):
+ - try:
+ - from triton.backends.compiler import AttrsDescriptor # noqa: F401
+ -
+ - return config.divisibility_16
+ - except ImportError:
+ - return config.divisible_by_16
+ -
+ - self.assertEqual(
 - (2,),
- + [(2,)],
- _check_divisibility(
- triton_utils.config_of(
- [
- @@ -63,7 +63,7 @@ class TestCodegenTriton(InductorTestCase):
+ - _check_divisibility(
+ - triton_utils.config_of(
+ - [
+ - SizeArg("A", two), # no
+ - SizeArg("B", eight), # no
+ - SizeArg("C", sixteen), # yes
+ - SizeArg("D", s0), # no
+ - SizeArg("E", s1), # no
+ - ]
+ - )
+ - ),
+ + config = triton_utils.config_of(
+ + [
+ + SizeArg("A", two), # no
+ + SizeArg("B", eight), # no
+ + SizeArg("C", sixteen), # yes
+ + SizeArg("D", s0), # no
+ + SizeArg("E", s1), # no
+ + ]
  )
-
- self.assertEqual(
+ -
+ - self.assertEqual(
 - (0, 2, 4, 5, 6),
- + [(0,), (2,), (4,), (5,), (6,)],
- _check_divisibility(
- triton_utils.config_of(
- [
+ - _check_divisibility(
+ - triton_utils.config_of(
+ - [
+ - SizeArg("A", two * eight), # 0: yes
+ - SizeArg("B", eight * s0), # 1: no
+ - SizeArg("C", two * eight * s0), # 2: yes
+ - SizeArg("D", s0 * s1), # 3: no
+ - SizeArg("E", sixteen * s0), # 4: yes
+ - SizeArg("F", sixteen * eight * s0 * s1), # 5: yes
+ - SizeArg("G", two * eight * s0 * s1), # 6: yes
+ - ]
+ - )
+ - ),
+ + # check for key
+ + config[((2,),)]
+ +
+ + config = triton_utils.config_of(
+ + [
+ + SizeArg("A", two * eight), # 0: yes
+ + SizeArg("B", eight * s0), # 1: no
+ + SizeArg("C", two * eight * s0), # 2: yes
+ + SizeArg("D", s0 * s1), # 3: no
+ + SizeArg("E", sixteen * s0), # 4: yes
+ + SizeArg("F", sixteen * eight * s0 * s1), # 5: yes
+ + SizeArg("G", two * eight * s0 * s1), # 6: yes
+ + ]
+ )
+ + # check for key
+ + config[((0,), (2,), (4,), (5,), (6,))]
+
+
+ if __name__ == "__main__":
 diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
- index 4d7a85029e3..f3d45ea5520 100644
+ index f674262ab6..fc2b8a6c27 100644
 --- a/test/inductor/test_triton_kernels.py
 +++ b/test/inductor/test_triton_kernels.py
- @@ -1268,9 +1268,9 @@ def forward(self, x_1, output_1):
- if dynamic:
- # when half_n_elements passed to the Triton kernel is
- # dynamic, equal_to_1 specializaiton can't be enforced
+ @@ -55,14 +55,6 @@ if HAS_GPU:
+ fast_dividef as my_fast_dividef,
+ )
+
+ - def _triton_get_ast_equal_to_str(params):
+ - try:
+ - from triton.backends.compiler import AttrsDescriptor # noqa: F401
+ -
+ - return f"'tt.equal_to': {params}"
+ - except ImportError:
+ - return f"equal_to_1={params}"
+ -
+ # Define shared triton constants here.
+ CONSTANT_C: tl.constexpr = 4
+ STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C"
+ @@ -1266,12 +1258,6 @@ def forward(self, x_1, output_1):
+ torch.compile(f, dynamic=dynamic), x, y
+ )
+
+ - if dynamic:
+ - # when half_n_elements passed to the Triton kernel is
+ - # dynamic, equal_to_1 specializaiton can't be enforced
 - self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
- + self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
- else:
+ - else:
 - self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
- + self.assertTrue(_triton_get_ast_equal_to_str([(3,)]) in sources[0])
  self.assertEqual(compiled_out, eager_out)

  @requires_gpu
- @@ -1299,7 +1299,7 @@ def forward(self, x_1, output_1):
+ @@ -1298,9 +1284,6 @@ def forward(self, x_1, output_1):
+ torch.compile(f, dynamic=dynamic), x, y
+ )

- # float 1.0 (both literal or symbolic)
- # should not be added to equal_to_1
+ - # float 1.0 (both literal or symbolic)
+ - # should not be added to equal_to_1
 - self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
- + self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
  self.assertEqual(compiled_out, eager_out)

  @requires_gpu
 diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
- index c3f72bc5215..03aab72dca9 100644
+ index c3f72bc521..016e30790a 100644
 --- a/torch/_higher_order_ops/triton_kernel_wrap.py
 +++ b/torch/_higher_order_ops/triton_kernel_wrap.py
- @@ -239,7 +239,7 @@ def generate_ttir(
+ @@ -184,6 +184,7 @@ def generate_ttir(
+ from triton.compiler.compiler import ASTSource
+ from triton.runtime.autotuner import Autotuner
+ from triton.runtime.jit import JITFunction
+ + from triton._utils import find_paths_if, get_iterable_path
+
+ import torch._inductor.ir
+ from torch._subclasses.fake_tensor import FakeTensor
+ @@ -233,24 +234,40 @@ def generate_ttir(
+ name for name, arg in ordered_args.items() if isinstance(arg, Tensor)
+ ]
+
+ - def _get_specialization(args): # type: ignore[no-untyped-def]
+ + def _get_specialization(kernel, *args): # type: ignore[no-untyped-def]
+ try:
+ - from triton.backends.compiler import AttrsDescriptor # noqa: F401
+ + from triton.runtime.jit import create_function_from_signature

  target = triton.runtime.driver.active.get_current_target()
  backend = triton.compiler.compiler.make_backend(target)
 - return backend.get_attrs_descriptor(args, kernel.params)
- + return backend.get_attrs_descriptor(kernel.params, args)
+ + # from: binder = create_function_from_signature(self.signature, self.params, backend)
+ + specialization = []
+ + # signature
+ + for name, kp in zip(kernel.signature.parameters.keys(), kernel.params):
+ + if kp.is_constexpr:
+ + specialization.append(f'("constexpr", {name})')
+ + else:
+ + is_const = 'True' if kp.is_const else 'False'
+ + specialize = 'False' if kp.do_not_specialize else 'True'
+ + align = 'False' if kp.do_not_specialize_on_alignment else 'True'
+ + ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})"
+ + if kp.annotation_type:
+ + specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
+ + else:
+ + specialization.append(f"{ret}")
+ + return specialization
  except ImportError:
  return kernel._get_config(*args)

- @@ -248,9 +248,10 @@ def generate_ttir(
+ - specialization = _get_specialization(ordered_args.values())
+ + specialization = _get_specialization(kernel, ordered_args.values())
+ constants = {
  name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
  }

@@ -71,7 +172,7 @@ index c3f72bc5215..03aab72dca9 100644
  for i, (name, arg) in enumerate(ordered_args.items())
  if i not in kernel.constexprs
  }
- @@ -258,7 +259,18 @@ def generate_ttir(
+ @@ -258,7 +275,22 @@ def generate_ttir(
  triton._C.libtriton.ir.load_dialects(context)
  backend.load_dialects(context)

@@ -87,10 +188,105 @@ index c3f72bc5215..03aab72dca9 100644
 + new_signature[arg_name] = signature[arg_name]
 + return new_signature
 +
- + src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, specialization)
+ + attrvals = [x[1] for x in specialization]
+ + from triton._utils import find_paths_if, get_iterable_path
+ + attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
+ + attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
+ + src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, attrs)

  # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
  # backward compatibility here.
+ diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
+ index a5fe2d1119..5ea0018957 100644
+ --- a/torch/_inductor/ir.py
+ +++ b/torch/_inductor/ir.py
+ @@ -5743,52 +5743,6 @@ class UserDefinedTritonKernel(ExternKernel):
+ for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
+ if kernel.arg_names.index(kwarg) in kernel.constexprs:
+ constexpr_indices.append(idx)
+ - """
+ - Filter out None args.
+ -
+ - see https://github.com/pytorch/pytorch/issues/115344
+ -
+ - Two cases for a None arg:
+ - 1. The arg is already tl.constexpr, so leave it in
+ - 2. The arg is not tl.constexpr so we have to remove it
+ - """
+ - constexpr_indices_set = OrderedSet(constexpr_indices)
+ - REMOVED = object()
+ - raw_args = [
+ - (
+ - (idx, arg)
+ - if (arg is not None) or (arg is None and idx in constexpr_indices_set)
+ - else (idx, REMOVED)
+ - )
+ - for idx, arg in enumerate(raw_args)
+ - ]
+ - removed_none_args = [idx for idx, val in raw_args if val == REMOVED]
+ - raw_args = [val for idx, val in raw_args if val != REMOVED]
+ -
+ - # We have to compute the constexpr indices for the new, filtered raw_args
+ - # We also have to adjust equal_to_1.
+ - if removed_none_args:
+ - eq1_indices_set = OrderedSet[int](triton_meta["configs"][0].equal_to_1)
+ - constexpr_indices = []
+ - equal_to_1 = []
+ - index_shift = 0
+ - for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel):
+ - # every time we encounter an idx we removed, adjust by one to account for it
+ - # So for example if we had [None, const X]
+ - # iter 1:
+ - # None was removed, adjust=1
+ - # iter 2:
+ - # X is const at idx=1, but the adjusted idx is 0 now, because None was removed
+ - if idx in removed_none_args:
+ - index_shift += 1
+ - continue
+ - arg_index = kernel.arg_names.index(kwarg)
+ - if arg_index in kernel.constexprs:
+ - constexpr_indices.append(idx - index_shift)
+ - if arg_index in eq1_indices_set:
+ - equal_to_1.append(idx - index_shift)
+ -
+ - triton_meta["configs"][0].equal_to_1 = equal_to_1
+
+ # Call to kernel
+ self.codegen_comment(wrapper)
+ diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py
+ index ccd69bf828..2c89659132 100644
+ --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py
+ +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py
+ @@ -551,28 +551,8 @@ class CppWrapperGpu(CppWrapperCpu):
+ )
+ kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
+
+ - # args with value 1 are added into equal_to_1 and constants
+ - # in triton_meta (in the Python codegen) which makes them
+ - # inlined in the PTX and compiled CUBIN
+ - arg_signatures = []
+ - if (
+ - triton_meta is not None
+ - and triton_meta.get("configs")
+ - and triton_meta.get("signature")
+ - ):
+ - equal_to_1 = triton_meta["configs"][0].equal_to_1
+ - call_args = [
+ - arg for i, arg in enumerate(call_args) if i not in equal_to_1
+ - ]
+ - arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
+ - # extract the arg signatures from triton_meta
+ - arg_signatures = triton_meta["signature"].values()
+ - arg_signatures = [
+ - v for i, v in enumerate(arg_signatures) if i not in equal_to_1
+ - ]
+ -
+ call_args_str = self.generate_args_decl(
+ - call_args, arg_types, arg_signatures
+ + call_args, arg_types, list(triton_meta["signature"].values())
+ )
+ kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
+ self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};")
 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
 index 00031a56b8d..59086d41b40 100644
 --- a/torch/_inductor/codegen/triton.py
@@ -173,7 +369,7 @@ index fa2a1334380..4d730fd45de 100644
  except ImportError:
  from triton.compiler.compiler import AttrsDescriptor
 diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
- index 281d0e78ba4..3b059a365c9 100644
+ index 281d0e78ba..8b3857bf2e 100644
 --- a/torch/_inductor/runtime/triton_heuristics.py
 +++ b/torch/_inductor/runtime/triton_heuristics.py
 @@ -414,10 +414,21 @@ class CachingAutotuner(KernelInterface):
@@ -213,19 +409,40 @@ index 281d0e78ba4..3b059a365c9 100644
  ]
  binary_shared = (
  binary.shared if hasattr(binary, "shared") else binary.metadata.shared
- @@ -952,6 +961,7 @@ class CachingAutotuner(KernelInterface):
+ @@ -646,8 +655,11 @@ class CachingAutotuner(KernelInterface):
+ )
+ # reset to zero before evaluating any config
+ self.reset_to_zero_args(*args, **kwargs)
+ + new_cloned_args = [*cloned_args]
+ + for arg_name, arg_value in launcher.config.kwargs.items():
+ + new_cloned_args.insert(self.fn.arg_names.index(arg_name), arg_value)
+ launcher(
+ - *cloned_args,
+ + *new_cloned_args,
+ **cloned_kwargs,
+ grid=grid,
+ stream=stream,
+ @@ -950,15 +962,21 @@ class CachingAutotuner(KernelInterface):
+ "stream": stream,
+ },
  ):
+ + new_args = [*args]
+ + for arg_name, arg_value in launcher.config.kwargs.items():
+ + new_args.insert(self.fn.arg_names.index(arg_name), arg_value)
  return launcher(
- *args,
- + **launcher.config.kwargs ,
+ - *args,
+ + *new_args ,
  **kwargs,
  grid=grid,
  stream=stream,
- @@ -959,6 +969,7 @@ class CachingAutotuner(KernelInterface):
+ )
  else:
+ + new_args = [*args]
+ + for arg_name, arg_value in launcher.config.kwargs.items():
+ + new_args.insert(self.fn.arg_names.index(arg_name), arg_value)
  return launcher(
- *args,
- + **launcher.config.kwargs ,
+ - *args,
+ + *new_args ,
  **kwargs,
  grid=grid,
  stream=stream,