Commit b2c89bc

davidberard98 authored and pytorchmergebot committed
[inductor][2/N] triton support post-pytorch#5512, user-defined triton kernels (pytorch#145348)
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor), and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined Triton kernel handling (in most cases) for these new Triton commits.

What this PR fixes:
* triton_kernel_wrap.py - AST->TTIR parsing had to be updated for the new Triton API
* ir.py - don't remove None args when using newer Triton versions
* wrapper.py - update signature & constant handling

What this doesn't fix:
* correct None handling - I want to take a closer look at constant handling (including None, equal_to_1, and other constants)
* the cpp wrapper (which needs to be fixed for both user-defined Triton kernels and inductor-generated kernels)

test/inductor/test_triton_kernels.py passed on Triton commit 74de6b46, with the exception of three tests (those shown here: pytorch@1374074).

Pull Request resolved: pytorch#145348
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#145051
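For context, a minimal sketch of the signature split this PR implements across the two Triton generations. The helper triton_version_uses_attrs_dict is real (it lives in torch._inductor.utils, per the diffs below); the standalone function and its type_of callback are illustrative, not inductor's exact code.

# Editor's sketch, assuming the behavior described in the diffs below.
from torch._inductor.utils import triton_version_uses_attrs_dict

def build_signature(ordered_args, constexpr_indices, type_of):
    if triton_version_uses_attrs_dict():
        # Newer Triton (post-triton#5512): constexpr args stay in the
        # signature, typed as the literal string "constexpr".
        return {
            name: "constexpr" if i in constexpr_indices else type_of(arg)
            for i, (name, arg) in enumerate(ordered_args.items())
        }
    # Older Triton: constexpr args are dropped from the signature entirely.
    return {
        name: type_of(arg)
        for i, (name, arg) in enumerate(ordered_args.items())
        if i not in constexpr_indices
    }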
1 parent b963ab5 · commit b2c89bc

File tree

5 files changed: +193, -83 lines

test/inductor/test_triton_kernels.py

Lines changed: 17 additions & 6 deletions

@@ -1231,6 +1231,8 @@ def f(x):
     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     def test_triton_kernel_equal_to_1_arg(self, dynamic):
+        from torch._inductor.utils import triton_version_uses_attrs_dict
+
         @triton.jit
         def add_kernel_half_n_elements(
             in_ptr0,
@@ -1263,17 +1265,25 @@ def f(x, y):
             torch.compile(f, dynamic=dynamic), x, y
         )

-        if dynamic:
-            # when half_n_elements passed to the Triton kernel is
-            # dynamic, equal_to_1 specialization can't be enforced
-            self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+        if triton_version_uses_attrs_dict():
+            self.assertFalse("equal_to" in sources[0])
         else:
-            self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
+            if dynamic:
+                # when half_n_elements passed to the Triton kernel is
+                # dynamic, equal_to_1 specialization can't be enforced
+
+                # also, equal_to_1 specialization doesn't occur (or appear in the
+                # signature) for newer versions of triton (i.e. the ones where
+                # triton_version_uses_attrs_dict() == True)
+                self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+            else:
+                self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0])
         self.assertEqual(compiled_out, eager_out)

     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     def test_triton_kernel_equal_to_1_float_arg(self, dynamic):
+        from torch._inductor.utils import triton_version_uses_attrs_dict
+
         def f(x, y):
             out = torch.empty_like(x)
             n_elements = x.numel()
@@ -1297,7 +1307,8 @@ def f(x, y):

         # float 1.0 (both literal or symbolic)
         # should not be added to equal_to_1
-        self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
+        if not triton_version_uses_attrs_dict():
+            self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0])
         self.assertEqual(compiled_out, eager_out)

     @requires_gpu
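
For reference, a minimal sketch of the kind of user-defined kernel these tests exercise, modeled on the test's add_kernel_half_n_elements (not the verbatim test body). When half_n_elements happens to be 1, older Triton records an equal_to_1 specialization, which is what the assertions above probe.

import triton
import triton.language as tl

@triton.jit
def add_kernel_half_n_elements(
    in_ptr0,
    in_ptr1,
    out_ptr,
    half_n_elements,  # candidate for equal_to_1 specialization when it is 1
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < half_n_elements * 2
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)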

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 77 additions & 12 deletions

@@ -172,10 +172,19 @@ def generate_ttir(
     """
     import sympy
     import triton
+    import triton.runtime.jit
     from triton.compiler.compiler import ASTSource
     from triton.runtime.autotuner import Autotuner
     from triton.runtime.jit import JITFunction

+    from torch._inductor.utils import (
+        get_triton_attrs_descriptor_version,
+        triton_version_uses_attrs_dict,
+        TritonAttrsDescriptorVersion,
+    )
+
+    triton_version = get_triton_attrs_descriptor_version()
+
     import torch._inductor.ir
     from torch._subclasses.fake_tensor import FakeTensor
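
The version enum referenced here lives in torch._inductor.utils; a sketch of its shape as implied by this diff (the member names appear verbatim below, while the numeric values and comments are assumptions):

import enum

class TritonAttrsDescriptorVersion(enum.Enum):
    # Member names as used in this diff; concrete values are illustrative.
    V1_COMPILER = 1        # AttrsDescriptor era, built via kernel._get_config
    V2_BACKENDS = 2        # AttrsDescriptor from triton.backends.compiler
    V3_BACKENDS_TUPLE = 3  # tuple support changes indexing (triton#5220)
    V4_DICT = 4            # AttrsDescriptor replaced by raw dicts (triton#5512)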

@@ -225,26 +234,78 @@ def generate_ttir(
     ]

     def _get_specialization(args):  # type: ignore[no-untyped-def]
-        try:
+        # Support multiple triton versions.
+        # This code basically copies JITFunction.run() logic to get the attrs
+        # needed to construct an ASTSource.
+        if triton_version == TritonAttrsDescriptorVersion.V1_COMPILER:
+            return kernel._get_config(*args)
+        elif triton_version in {
+            TritonAttrsDescriptorVersion.V2_BACKENDS,
+            TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE,
+        }:
             from triton.backends.compiler import AttrsDescriptor  # noqa: F401

             target = triton.runtime.driver.active.get_current_target()
-            backend = triton.compiler.compiler.make_backend(target)
-            return backend.get_attrs_descriptor(args, kernel.params)
-        except ImportError:
-            return kernel._get_config(*args)
+            backend_ = triton.compiler.compiler.make_backend(target)
+            return backend_.get_attrs_descriptor(args, kernel.params)
+        else:
+            assert (
+                get_triton_attrs_descriptor_version()
+                == TritonAttrsDescriptorVersion.V4_DICT
+            )
+            from triton._utils import find_paths_if, get_iterable_path
+            from triton.runtime.jit import specialize_impl
+
+            # logic is copied from: binder = create_function_from_signature(self.signature, self.params, backend)
+            attrvals = []
+            for arg, kp in zip(args, kernel.params):
+                if kp.is_constexpr:
+                    attrvals.append(arg)
+                else:
+                    spec = specialize_impl(
+                        arg,
+                        specialize_extra=backend.get_arg_specialization,
+                        is_const=kp.is_const,
+                        specialize_value=not kp.do_not_specialize,
+                        align=not kp.do_not_specialize_on_alignment,
+                    )
+                    attrvals.append(spec[1])
+
+            attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
+            attrs = {
+                k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs
+            }
+            return attrs

     specialization = _get_specialization(ordered_args.values())
     constants = {
         name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor)
     }

-    # Build kernel signature -- doesn't include constexpr arguments.
-    signature = {
-        name: kernel._type_of(kernel._key_of(arg))
-        for i, (name, arg) in enumerate(ordered_args.items())
-        if i not in kernel.constexprs
-    }
+    if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None:
+
+        def get_signature_value(idx: int, arg: Any) -> str:
+            if kernel.params[idx].is_constexpr:
+                return "constexpr"
+            return mangle_type(arg)
+
+    else:
+
+        def get_signature_value(idx: int, arg: Any) -> str:
+            return kernel._type_of(kernel._key_of(arg))
+
+    if triton_version_uses_attrs_dict():
+        # In newer versions of Triton, the signature includes constexpr args
+        signature = {
+            name: get_signature_value(i, arg)
+            for i, (name, arg) in enumerate(ordered_args.items())
+        }
+    else:
+        # In older versions of Triton, the signature does not include constexpr args
+        signature = {
+            name: get_signature_value(i, arg)
+            for i, (name, arg) in enumerate(ordered_args.items())
+            if i not in kernel.constexprs
+        }

     triton._C.libtriton.ir.load_dialects(context)
     backend.load_dialects(context)
@@ -254,13 +315,17 @@ def _get_specialization(args):  # type: ignore[no-untyped-def]
     # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
     # backward compatibility here.
     make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
+    get_codegen_implementation_sig_params = len(
+        inspect.signature(backend.get_codegen_implementation).parameters
+    )
     if make_ir_sig_params == 2:
         ttir_module = src.make_ir(options, context)
     elif make_ir_sig_params == 3:
         codegen_fns = backend.get_codegen_implementation()
         ttir_module = src.make_ir(options, codegen_fns, context)
     else:
-        codegen_fns = backend.get_codegen_implementation()
+        codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
+        codegen_fns = backend.get_codegen_implementation(*codegen_args)
         module_map = backend.get_module_map()
         ttir_module = src.make_ir(options, codegen_fns, module_map, context)
     if not ttir_module.verify():
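
The backward-compatibility block above dispatches on callee arity via inspect.signature. The same pattern in isolation (a generic sketch, not inductor's helper):

import inspect

def call_with_compatible_arity(fn, *available_args):
    # Pass only as many leading positional args as fn declares.
    n_params = len(inspect.signature(fn).parameters)
    return fn(*available_args[:n_params])

Inductor spells out each arity explicitly instead of truncating like this, since the make_ir argument lists are not simple prefixes of one another (the 2-argument form is (options, context), not (options, codegen_fns)).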

torch/_inductor/codegen/triton_utils.py

Lines changed: 10 additions & 4 deletions

@@ -7,7 +7,7 @@

 from .. import config
 from ..runtime.hints import AttrsDescriptorWrapper
-from ..utils import _type_of, expr_fits_within_32bit
+from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
     ConstexprArg,
@@ -55,9 +55,15 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
         return tye
     if isinstance(arg, SizeArg):
         if arg.expr is None:
-            # From triton/runtime/jit.py
-            # `None` is nullptr. Implicitly convert to *i8.
-            return "*i8"
+            if triton_version_uses_attrs_dict():
+                # In newer versions of Triton, the signature includes "None" args
+                # and their type is marked as "constexpr"
+                return "constexpr"
+            else:
+                # In older versions of Triton, from triton/runtime/jit.py:
+                # `None` is nullptr. Implicitly convert to *i8.
+                return "*i8"
         elif isinstance(arg.expr, (float, sympy.Float)):
             return "fp32"
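
A minimal expectation sketch of the resulting behavior for a None-valued size arg (the import paths are the real ones from this diff; the assertion is illustrative, not a test in the suite):

from torch._inductor.codegen.common import SizeArg
from torch._inductor.codegen.triton_utils import signature_of
from torch._inductor.utils import triton_version_uses_attrs_dict

# A SizeArg whose expr is None (e.g. an optional argument passed as None).
none_arg = SizeArg("maybe_ptr", None)
expected = "constexpr" if triton_version_uses_attrs_dict() else "*i8"
assert signature_of(none_arg, size_dtype=None) == expected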

torch/_inductor/codegen/wrapper.py

Lines changed: 41 additions & 18 deletions

@@ -42,6 +42,7 @@
     LineContext,
     sympy_product,
     sympy_str,
+    triton_version_uses_attrs_dict,
 )
 from ..virtualized import V
 from .common import (
@@ -1577,63 +1578,85 @@ def define_user_defined_triton_kernel(

         original_name = kernel.__name__

-        from .common import KernelArgType, SizeArg, TensorArg, TMADescriptorArg
+        from .common import (
+            ConstexprArg,
+            KernelArgType,
+            SizeArg,
+            TensorArg,
+            TMADescriptorArg,
+        )

         signature: list[KernelArgType] = []
         constants: dict[str, Any] = {}
         non_constant_indices = []
         equal_to_1_args: list[str] = []
+
+        def add_to_signature(idx, arg):
+            signature.append(arg)
+            non_constant_indices.append(idx)
+
         for idx, key in enumerate(kernel.arg_names):
+            if idx in kernel.constexprs:
+                if key in kwargs:
+                    constants[key] = kwargs[key]
+                if triton_version_uses_attrs_dict():
+                    add_to_signature(idx, ConstexprArg(name=key))
+                continue
+
             if key not in kwargs:
                 continue
+
             arg = kwargs[key]
-            if idx in kernel.constexprs:
-                constants[key] = arg
-            elif kwargs[key] is None:
+
+            if kwargs[key] is None:
                 constants[key] = None
             else:
-                non_constant_indices.append(idx)
                 if isinstance(arg, ir.TMADescriptor):
-                    signature.append(
+                    add_to_signature(
+                        idx,
                         TMADescriptorArg(
                             name=key,
-                        )
+                        ),
                     )
                 elif isinstance(arg, ir.Buffer):
-                    signature.append(
+                    add_to_signature(
+                        idx,
                         TensorArg(
                             name=key,
                             buffer=arg.get_name(),
                             dtype=arg.get_dtype(),
-                        )
+                        ),
                     )
                 elif isinstance(arg, ir.ReinterpretView):
                     # for ReinterpretView we use the underlying
                     # buffer name and note the (possibly non-zero)
                     # offset relative to the underlying buffer
-                    signature.append(
+                    add_to_signature(
+                        idx,
                         TensorArg(
                             name=key,
                             buffer=arg.data.get_name(),
                             dtype=arg.get_dtype(),
                             offset=arg.layout.offset,
-                        )
+                        ),
                     )
                 else:
-                    signature.append(SizeArg(key, arg))
+                    add_to_signature(idx, SizeArg(key, arg))
                     if isinstance(
                         arg, (int, sympy.Integer)
                     ) and V.graph.sizevars.statically_known_equals(
                         arg, 1  # type: ignore[arg-type]
                     ):
                         equal_to_1_args.append(key)
+
+        triton_signature = signature_to_meta(
+            signature,
+            size_dtype=None,  # try to infer based on symints
+            indices=non_constant_indices,
+            argdefs=kernel.arg_names,
+        )
         triton_meta: dict[str, Any] = {
-            "signature": signature_to_meta(
-                signature,
-                size_dtype=None,  # try to infer based on symints
-                indices=non_constant_indices,
-                argdefs=kernel.arg_names,
-            ),
+            "signature": triton_signature,
             "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
             # Triton compiler includes equal_to_1 args into constants even
             # when they are not constexpr. otherwise there may be a segfault
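
Put together, for a hypothetical kernel def k(x_ptr, n, BLOCK: tl.constexpr) launched with BLOCK=128, the metadata built above would plausibly look like this (the type strings and dict contents are assumptions, not captured output):

from torch._inductor.utils import triton_version_uses_attrs_dict

if triton_version_uses_attrs_dict():
    # Newer Triton: constexpr args appear in the signature as "constexpr".
    expected_signature = {"x_ptr": "*fp32", "n": "i32", "BLOCK": "constexpr"}
else:
    # Older Triton: constexpr args are omitted from the signature.
    expected_signature = {"x_ptr": "*fp32", "n": "i32"}
expected_constants = {"BLOCK": 128}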
