diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py
- index 84264bf1b0..eb3fac0f39 100644
+ index 84264bf1b01..aa9a624d5ac 100644
--- a/test/inductor/test_codegen_triton.py
+++ b/test/inductor/test_codegen_triton.py
@@ -48,7 +48,7 @@ class TestCodegenTriton(InductorTestCase):
@@ -16,12 +16,37 @@ index 84264bf1b0..eb3fac0f39 100644

self.assertEqual(
- (0, 2, 4, 5, 6),
- + [(0, 2, 4, 5, 6)],
+ + [(0,), (2,), (4,), (5,), (6,)],
_check_divisibility(
triton_utils.config_of(
[
+ diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
+ index 4d7a85029e3..f3d45ea5520 100644
+ --- a/test/inductor/test_triton_kernels.py
+ +++ b/test/inductor/test_triton_kernels.py
+ @@ -1268,9 +1268,9 @@ def forward(self, x_1, output_1):
+ if dynamic:
+ # when half_n_elements passed to the Triton kernel is
+ # dynamic, equal_to_1 specializaiton can't be enforced
+ - self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+ + self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
+ else:
+ - self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
+ + self.assertTrue(_triton_get_ast_equal_to_str([(3,)]) in sources[0])
+ self.assertEqual(compiled_out, eager_out)
+
+ @requires_gpu
+ @@ -1299,7 +1299,7 @@ def forward(self, x_1, output_1):
+
+ # float 1.0 (both literal or symbolic)
+ # should not be added to equal_to_1
+ - self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+ + self.assertTrue(_triton_get_ast_equal_to_str([]) in sources[0])
+ self.assertEqual(compiled_out, eager_out)
+
+ @requires_gpu
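
Note: the test updates above encode the newer attribute format, in which each argument index is wrapped in its own tuple instead of being collected into one flat tuple. A minimal, self-contained sketch of that conversion (the helper name below is illustrative, not part of the test suite):

    # Hypothetical helper: convert a flat index tuple from the old patch
    # revision into the per-index tuples the updated tests expect.
    def nest_indices(flat):
        return [(i,) for i in flat]

    assert nest_indices((0, 2, 4, 5, 6)) == [(0,), (2,), (4,), (5,), (6,)]
    assert nest_indices(()) == []  # matches the empty equal_to_1 cases above
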
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
- index ace56135fe..568cbde0a1 100644
+ index ace56135fe1..7e925dd6e45 100644
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
@@ -238,7 +238,7 @@ def generate_ttir(
@@ -33,43 +58,41 @@ index ace56135fe..568cbde0a1 100644
except ImportError:
return kernel._get_config(*args)

- @@ -251,7 +251,6 @@ def generate_ttir(
+ @@ -247,7 +247,8 @@ def generate_ttir(
+ name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
+ }
+
+ - # Build kernel signature -- doesn't include constexpr arguments.
+ + # Build kernel signature; it should also include `constexpr` arguments but `kernel._key_of`
+ + # doesn't work correctly with them. They will be added in `fixup_signature` function later.
signature = {
name: kernel._type_of(kernel._key_of(arg))
for i, (name, arg) in enumerate(ordered_args.items())
- - if i not in kernel.constexprs
- }
-
+ @@ -257,7 +258,18 @@ def generate_ttir(
triton._C.libtriton.ir.load_dialects(context)
+ backend.load_dialects(context)
+
+ - src = ASTSource(kernel, signature, constants, specialization)
+ + def fixup_signature(arg_names, signature, constants):
+ + new_signature = {arg_name: None for arg_name in arg_names}
+ + for arg_name in arg_names:
+ + if arg_name in constants and arg_name not in signature:
+ + # If it's not in the signature already, it's a constexpr
+ + # argument that we need to add in
+ + new_signature[arg_name] = "constexpr"
+ + else:
+ + new_signature[arg_name] = signature[arg_name]
+ + return new_signature
+ +
+ + src = ASTSource(kernel, fixup_signature(kernel.arg_names, signature, constants), constants, specialization)
+
+ # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
+ # backward compatibility here.
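
For reference, the fixup_signature helper introduced in the hunk above can be exercised on its own. The function body is copied from the patch; the argument names, types, and values below are made-up placeholders:

    def fixup_signature(arg_names, signature, constants):
        new_signature = {arg_name: None for arg_name in arg_names}
        for arg_name in arg_names:
            if arg_name in constants and arg_name not in signature:
                # Not in the signature already, so it is a constexpr argument.
                new_signature[arg_name] = "constexpr"
            else:
                new_signature[arg_name] = signature[arg_name]
        return new_signature

    # Illustrative inputs only:
    arg_names = ["in_ptr", "out_ptr", "n_elements", "BLOCK_SIZE"]
    signature = {"in_ptr": "*fp32", "out_ptr": "*fp32", "n_elements": "i32"}
    constants = {"BLOCK_SIZE": 1024}
    print(fixup_signature(arg_names, signature, constants))
    # {'in_ptr': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i32', 'BLOCK_SIZE': 'constexpr'}
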
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
- index 00031a56b8..b941e2aaa6 100644
+ index 00031a56b8d..59086d41b40 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
- @@ -2980,6 +2980,7 @@ class TritonKernel(SIMDKernel):
- code.splice(self.imports_for_benchmark_kernel())
-
- argdefs, _, signature, _ = self.args.python_argdefs()
- + # breakpoint()
- # maps actual expression to SizeArg if it is in sizevars replacements
- for i, arg in enumerate(signature):
- if isinstance(arg, SizeArg):
- @@ -3030,7 +3031,7 @@ class TritonKernel(SIMDKernel):
- triton_meta = {
- "signature": triton_meta_signature,
- "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
- - "constants": {},
- + "constexprs": {},
- }
-
- # Skip memory optimization for forward of the training loop where we expect
- @@ -3065,20 +3066,12 @@ class TritonKernel(SIMDKernel):
- argdefs.append(f"{tree.prefix}numel")
- # constexpr version causes issues, see
- # https://github.com/pytorch/torchdynamo/pull/1362
- - # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
- + # triton_meta["constexprs"][len(argdefs)] = V.graph.sizevars.size_hint(
- # tree.numel
- # )
+ @@ -3071,14 +3071,6 @@ class TritonKernel(SIMDKernel):
# argdefs.append(f"{tree.prefix}numel: tl.constexpr")
triton_meta["configs"] = [config_of(signature)]

@@ -84,175 +107,58 @@ index 00031a56b8..b941e2aaa6 100644
self.triton_meta = triton_meta

for tree in self.range_trees:
- @@ -3087,9 +3080,14 @@ class TritonKernel(SIMDKernel):
- continue
- if tree.tensor_dim is None:
- continue
- - argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
- + const_name = f"{tree.prefix.upper()}BLOCK"
- + triton_meta['signature'][const_name] = 'constexpr'
- + triton_meta['constexprs'][const_name] = tree.numel
- + argdefs.append(f"{const_name} : tl.constexpr")
-
- if self.cooperative_reduction:
- + triton_meta['signature']['RSPLIT'] = 'constexpr'
- + triton_meta['constexprs']['RSPLIT'] = tree.numel
- argdefs.append("RSPLIT : tl.constexpr")
-
- self.codegen_body()
- diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
- index 8b8c29bbb1..3e5abaa824 100644
- --- a/torch/_inductor/codegen/triton_utils.py
- +++ b/torch/_inductor/codegen/triton_utils.py
- @@ -157,13 +157,13 @@ def config_of(
- raise NotImplementedError(f"unhandled {type(x)}: {x}")
-
- if config.triton.divisible_by_16:
- - divisible_by_16 = tuple(
- + divisible_by_16 = [tuple(
- i
- for i, arg in zip(indices, args)
- if is_aligned(arg, alignment=16, include_tensor=True)
- - )
- + )]
- else:
- - divisible_by_16 = ()
- + divisible_by_16 = []
-
- equal_to_1 = tuple(
- i
- @@ -172,5 +172,7 @@ def config_of(
- and isinstance(arg.expr, (int, sympy.Integer))
- and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
- )
- + if equal_to_1 != tuple():
- + equal_to_1 = [equal_to_1]
-
- return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
- index 2ab2b32635..5f08c3c0b7 100644
+ index 2ab2b326354..42d76b8bf94 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
- @@ -1535,16 +1535,21 @@ class PythonWrapperCodegen(CodeGen):
-
- signature: List[KernelArgType] = []
- constants: Dict[str, Any] = {}
- + constexprs = {}
- non_constant_indices = []
- equal_to_1_args: List[str] = []
- for idx, key in enumerate(kernel.arg_names):
- if key not in kwargs:
- + if idx in kernel.constexprs:
- + constexprs[key] = 'constexpr'
- continue
- arg = kwargs[key]
- if idx in kernel.constexprs:
- constants[key] = arg
- + constexprs[key] = 'constexpr'
- elif kwargs[key] is None:
- constants[key] = None
- + constexprs[key] = 'constexpr'
- else:
- non_constant_indices.append(idx)
- if isinstance(arg, ir.TMADescriptor):
- @@ -1596,9 +1601,8 @@ class PythonWrapperCodegen(CodeGen):
- # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
- # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+ @@ -1598,7 +1598,6 @@ class PythonWrapperCodegen(CodeGen):
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
- - "constants": {
- + "constexprs": {
+ "constants": {
**constants,
- **dict.fromkeys(equal_to_1_args, 1),
},
"configs": [
config_of(
- @@ -1607,6 +1611,8 @@ class PythonWrapperCodegen(CodeGen):
- )
- ],
- }
- + for constexpr_name in constexprs.keys():
- + triton_meta['signature'][constexpr_name] = 'constexpr'
-
- if restore_value_args:
- triton_meta["restore_value"] = tuple(restore_value_args)
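
In the updated patch this file only loses the merge of equal-to-1 arguments into the constants dict. A small sketch with made-up values showing what the removed line used to contribute:

    constants = {"BLOCK_SIZE": 128}   # illustrative values
    equal_to_1_args = ["stride_x"]

    old_constants = {**constants, **dict.fromkeys(equal_to_1_args, 1)}
    new_constants = {**constants}

    print(old_constants)  # {'BLOCK_SIZE': 128, 'stride_x': 1}
    print(new_constants)  # {'BLOCK_SIZE': 128}
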
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
- index 276c01f3f4..4e6e1ab9ce 100644
+ index 276c01f3f42..5c633b7963b 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
- @@ -53,6 +53,7 @@ if _is_triton_available():
+ @@ -48,8 +48,8 @@ if _is_triton_available():
+ ):
+ # Prepare the arguments for AttrsDescriptor
+ kwargs = {
+ - "tt.divisibility": divisible_by_16,
+ - "tt.equal_to": equal_to_1,
+ + "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
+ + "tt.equal_to": tuple([(i,) for i in equal_to_1]),
}

# Instantiate AttrsDescriptor with the prepared arguments
- + # breakpoint()
- res = AttrsDescriptor.from_dict(
- {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
- )
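
The hints.py change adapts the prepared AttrsDescriptor kwargs to the nested per-index form in place. A standalone sketch of that transformation (the index values are illustrative):

    divisible_by_16 = (0, 1, 3)   # illustrative indices
    equal_to_1 = (2,)

    kwargs = {
        "tt.divisibility": tuple([(i,) for i in divisible_by_16]),
        "tt.equal_to": tuple([(i,) for i in equal_to_1]),
    }
    print(kwargs)
    # {'tt.divisibility': ((0,), (1,), (3,)), 'tt.equal_to': ((2,),)}
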
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
- index af8530e94d..a1935831e2 100644
+ index af8530e94d0..1ec44de9806 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
- @@ -407,6 +407,7 @@ class CachingAutotuner(KernelInterface):
-
- def _precompile_config(self, cfg: Config, warm_cache_only: bool):
- """Ahead of time compile a given autotuner config."""
- + # print(f"self.triton_meta: {self.triton_meta}")
- compile_meta = copy.deepcopy(self.triton_meta)
- for k, v in cfg.kwargs.items():
- if self.device_props.type == "hip":
- @@ -419,7 +420,7 @@ class CachingAutotuner(KernelInterface):
- if k == "kpack":
- compile_meta["kpack"] = v
- continue
- - compile_meta["constants"][k] = v
- + compile_meta["constexprs"][k] = v
- compile_meta["num_warps"] = cfg.num_warps
- compile_meta["num_stages"] = cfg.num_stages
- compile_meta["debug"] = self.inductor_meta.get(
- @@ -435,12 +436,13 @@ class CachingAutotuner(KernelInterface):
+ @@ -435,11 +435,22 @@ class CachingAutotuner(KernelInterface):
else:
triton_helpers.set_driver_to_gpu()

- + # print(compile_meta)
+ + def fixup_signature(arg_names, signature, constants):
+ + new_signature = {arg_name: None for arg_name in arg_names}
+ + for arg_name in arg_names:
+ + if arg_name in constants and arg_name not in signature:
+ + # If it's not in the signature already, it's a constexpr
+ + # argument that we need to add in
+ + new_signature[arg_name] = "constexpr"
+ + else:
+ + new_signature[arg_name] = signature[arg_name]
+ + return new_signature
+ +
if ASTSource:
compile_args = (
ASTSource(
self.fn,
- compile_meta["signature"],
- - compile_meta["constants"],
- + compile_meta["constexprs"],
+ - compile_meta["signature"],
+ + fixup_signature(self.fn.arg_names, compile_meta["signature"], compile_meta["constants"]),
+ compile_meta["constants"],
compile_meta["configs"][0],
),
- )
- @@ -527,7 +529,7 @@ class CachingAutotuner(KernelInterface):
- We also don't want to modify self.fn.
-
- We know that we removed something from the signature if:
- - 1. It's in compile_meta["constants"]
- + 1. It's in compile_meta["constexprs"]
- 2. It isn't a constant we already know about
- Note: The value of interest has already been added to compile_meta['constants'],
- so we use self.fn.constexprs instead.
- @@ -538,7 +540,7 @@ class CachingAutotuner(KernelInterface):
- }
- none_args = {
- k
- - for k, v in compile_meta["constants"].items()
- + for k, v in compile_meta["constexprs"].items()
- if v is None and k not in known_constants
- }
- none_args = none_args.difference(set(compile_meta["signature"].keys()))
- @@ -548,12 +550,14 @@ class CachingAutotuner(KernelInterface):
- for i, arg in enumerate(self.fn.arg_names)
- if i not in self.fn.constexprs and arg not in none_args
- ]
- + # print(f"call_args: {call_args}")
-
- def_args = [
- name
- for name in self.fn.arg_names
- if name not in cfg.kwargs and name not in none_args
- ]
- + # print(f"def_args: {def_args}\n")
- binary_shared = (
- binary.shared if hasattr(binary, "shared") else binary.metadata.shared
- )
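
The same fixup applies here: autotuner config kwargs are merged into compile_meta["constants"] earlier in _precompile_config, so the signature handed to ASTSource must name them as constexpr. A compact, behaviorally equivalent restatement of the helper with made-up compile_meta fragments:

    def fixup_signature(arg_names, signature, constants):
        # Behaviorally equivalent restatement of the helper in the hunks above.
        return {
            name: ("constexpr" if name in constants and name not in signature else signature[name])
            for name in arg_names
        }

    # Illustrative fragments only:
    arg_names = ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK"]
    signature = {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"}
    constants = {"XBLOCK": 64}   # e.g. merged in from cfg.kwargs
    print(fixup_signature(arg_names, signature, constants))
    # {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}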