
Commit b963ab5

davidberard98 authored and pytorchmergebot committed
[inductor][1/N] triton support post-pytorch#5512, main components (pytorch#145051)
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This is an initial PR to add support for Triton versions after commit 5512 landed.

The main changes in 5220 and 5512 that need to be supported:
* AttrsDescriptor() gets replaced with a raw dict. The raw dict has the format `{(TUPLES): [["tt.divisibility", 16]]}`, where `(TUPLES)` is a tuple of indices, e.g. `((0,), (1,), (3,))`, indicating that args 0, 1, and 3 are divisible by 16. These indices are themselves represented as tuples to support nested inputs (e.g. an argument that's a tuple), but support for tuples is not implemented right now.
* "signature" changes: the signature now contains _all_ args, including constexpr and constant args.
* ASTSource now takes "constexprs" instead of "constants" - for example, equal-to-1 args are constants but not constexprs, so we don't need to pass these args as "constants".

What this PR supports:
* Triton versions before Dec 9, 2024, and (partial support for) Triton versions after Jan 1, 2025
* (triton Jan 1+) typical Inductor-generated Triton: updated AttrsDescriptor, signatures, constexpr/constant handling

What this PR doesn't support (TODO in follow-up PRs):
* Triton versions between Dec 9, 2024 and Jan 1, 2025
* (triton Jan 1+) user-defined Triton kernel support (this is implemented already in @anmyachev's patch)
* (triton Jan 1+) triton_helper support (failing in Triton codegen - needs investigation)
* (triton Jan 1+) AOTI / cpp wrapper

Thanks to @anmyachev for the patches in https://github.com/intel/intel-xpu-backend-for-triton/blob/main/scripts/pytorch.patch, which already contain most of these changes.

Pull Request resolved: pytorch#145051
Approved by: https://github.com/jansel
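
As a quick illustration of the new attrs format described above (a sketch only, not Triton's actual API surface; it just shows the dict shape Inductor now has to emit):

```python
# Old style (pre triton#5512): hints were carried by an AttrsDescriptor-like
# object, roughly AttrsDescriptor(divisible_by_16=(0, 1, 3), equal_to_1=(5,)).

# New style (post #5512): a raw dict keyed by a tuple of per-argument index
# tuples. The nested tuples leave room for tuple-valued arguments, which
# Inductor does not emit yet.
attrs = {
    ((0,), (1,), (3,)): [["tt.divisibility", 16]],  # args 0, 1, 3 divisible by 16
}

for arg_indices, properties in attrs.items():
    for prop_name, prop_value in properties:
        print(f"{prop_name}={prop_value} for args {[i[0] for i in arg_indices]}")
# -> tt.divisibility=16 for args [0, 1, 3]
```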
1 parent 714f643 commit b963ab5

File tree: 8 files changed, +200 -58 lines changed


test/inductor/test_torchinductor.py

Lines changed: 19 additions & 2 deletions
@@ -9271,14 +9271,23 @@ def run(x):
             self.assertEqual(fw_code.count("halide_helpers.rand"), 2)
             self.assertEqual(bw_code.count("halide_helpers.rand"), 0)
         elif self.device == GPU_TYPE:
-            self.assertEqual(fw_code.count("tl.rand"), 2)
+            # the load_seed_offset arg can be 1 or non-1; depending on whether
+            # the triton signature specializes on 1 vs non-1, you might get 1
+            # or 2 kernels. In newer versions of triton, there's no specialization
+            # so we get only 1 kernel.
+            from torch._inductor.utils import triton_version_uses_attrs_dict
+
+            expected_kernels = 1 if triton_version_uses_attrs_dict() else 2
+            self.assertEqual(fw_code.count("tl.rand"), expected_kernels)
             self.assertEqual(bw_code.count("tl.rand"), 0)
         self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

     def test_randint_kernel_count(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("Only valid for GPU!")

+        from torch._inductor.utils import triton_version_uses_attrs_dict
+
         @torch._dynamo.optimize_assert("inductor")
         def fn1():
             random_tensor1 = torch.randint(10, [32], device=self.device)
@@ -9289,7 +9298,15 @@ def fn1():
         _, source_codes = run_and_get_code(fn1)
         # cpp_wrapper does a 2-pass generation on GPU.
         self.assertEqual(len(source_codes), 1)
-        self.assertEqual(source_codes[0].count("async_compile.triton"), 2)
+
+        # the load_seed_offset arg can be 1 or non-1; depending on whether
+        # the triton signature specializes on 1 vs non-1, you might get 1
+        # or 2 kernels. In newer versions of triton, there's no specialization
+        # so we get only 1 kernel.
+        expected_kernels = 1 if triton_version_uses_attrs_dict() else 2
+        self.assertEqual(
+            source_codes[0].count("async_compile.triton"), expected_kernels
+        )

     def test_roll(self):
         def fn(a):

torch/_inductor/codegen/common.py

Lines changed: 5 additions & 0 deletions
@@ -189,6 +189,11 @@ def alias_of(self):
         return None


+@dataclasses.dataclass
+class ConstexprArg:
+    name: str
+
+
 @dataclasses.dataclass
 class TMADescriptorArg:
     name: str

torch/_inductor/codegen/triton.py

Lines changed: 43 additions & 33 deletions
@@ -56,13 +56,15 @@
     sympy_product,
     sympy_subs,
     triton_type,
+    triton_version_uses_attrs_dict,
     upcast_compute_type,
 )
 from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
 from ..wrapper_benchmark import get_kernel_category_by_source_code
 from .block_analysis import BlockPatternMatcher
 from .common import (
     BackendFeature,
+    ConstexprArg,
     CSE,
     CSEVariable,
     DeferredLine,
@@ -85,8 +87,8 @@
 )
 from .triton_utils import (
     config_of,
+    non_constexpr_signature,
     should_unwrap_unspec_arg,
-    signature_of,
     signature_to_meta,
 )

@@ -3357,6 +3359,35 @@ def codegen_kernel(self, name=None):

         mutated_args = sorted(mutated_args)

+        for tree in self.active_range_trees():
+            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
+            signature.append(sizearg)
+            argdefs.append(sizearg.name)
+            # constexpr version causes issues, see
+            # https://github.com/pytorch/torchdynamo/pull/1362
+            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
+            #     tree.numel
+            # )
+            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
+
+        def add_constexpr_arg(arg_name):
+            # new versions (but not old versions) of Triton need constexprs included in the signature
+            if triton_version_uses_attrs_dict():
+                signature.append(ConstexprArg(arg_name))
+            argdefs.append(f"{arg_name} : tl.constexpr")
+
+        for tree in self.range_trees:
+            if tree.is_reduction and self.persistent_reduction:
+                # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels
+                continue
+            if tree.tensor_dim is None:
+                continue
+
+            add_constexpr_arg(f"{tree.prefix.upper()}BLOCK")
+
+        if self.cooperative_reduction:
+            add_constexpr_arg("RSPLIT")
+
         triton_meta_signature = signature_to_meta(
             signature, size_dtype=self.index_dtype, argdefs=argdefs
         )
@@ -3390,42 +3421,19 @@ def codegen_kernel(self, name=None):
             num_gb = self.estimate_kernel_num_bytes() / 1e9
             inductor_meta["kernel_num_gb"] = num_gb

-        for tree in self.active_range_trees():
-            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
-            signature.append(sizearg)
-            triton_meta_signature[sizearg.name] = signature_of(
-                sizearg, size_dtype=self.index_dtype
-            )
-            argdefs.append(f"{tree.prefix}numel")
-            # constexpr version causes issues, see
-            # https://github.com/pytorch/torchdynamo/pull/1362
-            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
-            #     tree.numel
-            # )
-            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
         triton_meta["configs"] = [config_of(signature)]

-        # Triton compiler includes equal_to_1 args into constants even
-        # when they are not constexpr. otherwise there may be a segfault
-        # during launching the Inductor-compiled Triton kernel.
-        # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
-        # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
-        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
-            triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]
+        if not triton_version_uses_attrs_dict():
+            # Triton compiler includes equal_to_1 args into constants even
+            # when they are not constexpr. otherwise there may be a segfault
+            # during launching the Inductor-compiled Triton kernel.
+            # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
+            # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
+            for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
+                triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]

         self.triton_meta = triton_meta

-        for tree in self.range_trees:
-            if tree.is_reduction and self.persistent_reduction:
-                # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels
-                continue
-            if tree.tensor_dim is None:
-                continue
-            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
-
-        if self.cooperative_reduction:
-            argdefs.append("RSPLIT : tl.constexpr")
-
         self.codegen_body()

         for helper in self.helper_functions:
@@ -3457,7 +3465,9 @@ def codegen_kernel(self, name=None):
         else:
             tile_hint = ""
             if len(size_hints) == 2:
-                if len(signature) == 4:  # input, output and 2 args
+                if (
+                    len(non_constexpr_signature(signature)) == 4
+                ):  # input, output and 2 args
                     tile_hint = "tile_hint=TileHint.SQUARE,"
                 else:
                     tile_hint = "tile_hint=TileHint.DEFAULT,"
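
The pattern in the hunks above: block-size constexprs always go into the Python argdefs, but they are only appended to the signature when the installed Triton expects the new attrs-dict format. A simplified, self-contained sketch of that gating (the stand-in `triton_version_uses_attrs_dict` and the tuple-based signature entries are placeholders for the real `torch._inductor.utils` helper and `ConstexprArg` objects):

```python
def triton_version_uses_attrs_dict() -> bool:
    # stand-in: the real helper inspects the installed Triton version
    return True


def add_constexpr_arg(arg_name, signature, argdefs):
    # mirrors the helper added above: newer Triton wants constexprs in the
    # signature as well, older Triton only wants them in the argdef list
    if triton_version_uses_attrs_dict():
        signature.append(("constexpr", arg_name))
    argdefs.append(f"{arg_name} : tl.constexpr")


signature = ["in_ptr0", "out_ptr0", "xnumel"]
argdefs = ["in_ptr0", "out_ptr0", "xnumel"]
add_constexpr_arg("XBLOCK", signature, argdefs)
print(signature)  # ['in_ptr0', 'out_ptr0', 'xnumel', ('constexpr', 'XBLOCK')]
print(argdefs)    # ['in_ptr0', 'out_ptr0', 'xnumel', 'XBLOCK : tl.constexpr']
```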

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 15 additions & 8 deletions
@@ -16,9 +16,10 @@
 from ..runtime.runtime_utils import next_power_of_2
 from ..runtime.triton_heuristics import grid_combo_kernels
 from ..scheduler import BaseSchedulerNode
-from ..utils import Placeholder
+from ..utils import Placeholder, triton_version_uses_attrs_dict
 from ..virtualized import V
 from .common import (
+    ConstexprArg,
     DeferredLine,
     IndentedBuffer,
     Kernel,
@@ -727,8 +728,12 @@ def codegen_blocks(self, code: IndentedBuffer) -> None:
             code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}")
             code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")

-    def add_blockd_to_args(self, argdefs: list[str]) -> list[str]:
-        block_args = {}
+    def get_block_args(self) -> list[ConstexprArg]:
+        """
+        Calculate blocks from sub_kernels and range_trees.
+        **Update self.block_args**
+        Return the block args
+        """
         block_names = {}
         for sub_kernel in self.sub_kernels:
             # TODO: we assume all sub_kernels have the same block size
@@ -739,13 +744,10 @@ def add_blockd_to_args(self, argdefs: list[str]) -> list[str]:
                     continue
                 if tree.prefix == "x" and sub_kernel.no_x_dim:
                     continue
-                block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix
                 block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix
-        if self.enable_autotune:
-            argdefs.extend(block_args)
         self.block_args = list(block_names.keys())

-        return argdefs
+        return [ConstexprArg(x) for x in block_names.keys()]

     def add_numel_to_args(self, argdefs: list[str], signature: list[Any]) -> list[str]:
         for num, sub_kernel in enumerate(self.sub_kernels):
@@ -830,7 +832,12 @@ def codegen_kernel(self, name: Optional[str] = None) -> str:

         argdefs, _, signature, _ = self.args.python_argdefs()
         argdefs = self.add_numel_to_args(argdefs, signature)
-        argdefs = self.add_blockd_to_args(argdefs)
+        block_args = self.get_block_args()
+        if self.enable_autotune:
+            argdefs.extend([f"{x.name}: tl.constexpr" for x in block_args])
+        if triton_version_uses_attrs_dict():
+            signature.extend(block_args)
+
         code.splice(
             self.jit_line(
                 heuristics,
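
For combo kernels the same split happens in two independent steps: `enable_autotune` decides whether the block constexprs become kernel parameters, while the Triton version decides whether they are recorded in the signature. A hypothetical, simplified sketch of that wiring (not the real class method):

```python
def triton_version_uses_attrs_dict() -> bool:
    # stand-in for the real check in torch._inductor.utils
    return True


def finalize_combo_args(argdefs, signature, block_args, enable_autotune):
    if enable_autotune:
        # block sizes become kernel parameters only when autotuning is enabled
        argdefs.extend(f"{name}: tl.constexpr" for name in block_args)
    if triton_version_uses_attrs_dict():
        # newer Triton records them in the signature either way
        signature.extend(block_args)
    return argdefs, signature


print(finalize_combo_args(["in_ptr0", "xnumel"], [], ["XBLOCK", "R0_BLOCK"], enable_autotune=False))
# (['in_ptr0', 'xnumel'], ['XBLOCK', 'R0_BLOCK'])
```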

torch/_inductor/codegen/triton_utils.py

Lines changed: 20 additions & 2 deletions
@@ -9,7 +9,14 @@
 from ..runtime.hints import AttrsDescriptorWrapper
 from ..utils import _type_of, expr_fits_within_32bit
 from ..virtualized import V
-from .common import KernelArgType, SizeArg, TensorArg, TMADescriptorArg, WorkspaceArg
+from .common import (
+    ConstexprArg,
+    KernelArgType,
+    SizeArg,
+    TensorArg,
+    TMADescriptorArg,
+    WorkspaceArg,
+)


 def should_unwrap_unspec_arg(name: str):
@@ -73,9 +80,20 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
         return _type_of(arg.dtype)
     if isinstance(arg, TMADescriptorArg):
         return "nvTmaDesc"
+    if isinstance(arg, ConstexprArg):
+        return "constexpr"
     raise NotImplementedError(f"unhandled {type(arg)}: {arg}")


+def non_constexpr_signature(signature):
+    new_signature = []
+    for arg in signature:
+        if not isinstance(arg, ConstexprArg):
+            new_signature.append(arg)
+
+    return new_signature
+
+
 def signature_to_meta(
     signature: list[KernelArgType],
     *,
@@ -152,7 +170,7 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
     if isinstance(x, WorkspaceArg):
         # We allocate the workspace ourselves, so it is always aligned
         return True
-    if isinstance(x, TMADescriptorArg):
+    if isinstance(x, (TMADescriptorArg, ConstexprArg)):
         return False
     raise NotImplementedError(f"unhandled {type(x)}: {x}")
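
A self-contained sketch of the new filtering helper, using stand-in dataclasses instead of Inductor's real arg classes (only `ConstexprArg` matches the definition added in common.py above; `SizeArg` is simplified here):

```python
import dataclasses


@dataclasses.dataclass
class ConstexprArg:
    # mirrors the dataclass added to torch/_inductor/codegen/common.py
    name: str


@dataclasses.dataclass
class SizeArg:
    # simplified stand-in; the real SizeArg holds a sympy expression
    name: str
    expr: int


def non_constexpr_signature(signature):
    # same filtering idea as the helper above: drop constexpr entries, e.g.
    # before counting "real" inputs for the TileHint.SQUARE heuristic
    return [arg for arg in signature if not isinstance(arg, ConstexprArg)]


signature = [SizeArg("xnumel", 1024), SizeArg("r0_numel", 64), ConstexprArg("XBLOCK")]
print([arg.name for arg in non_constexpr_signature(signature)])  # ['xnumel', 'r0_numel']
```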

torch/_inductor/runtime/hints.py

Lines changed: 17 additions & 1 deletion
@@ -47,6 +47,7 @@ def _is_triton_available() -> bool:
     import triton.compiler.compiler

     if hasattr(triton.backends.compiler, "AttrsDescriptor"):
+        # Triton 3.2.0 - the second implementation
        from triton.backends.compiler import AttrsDescriptor

         def AttrsDescriptorWrapper(
@@ -67,7 +68,8 @@ def AttrsDescriptorWrapper(
             assert res.property_values["tt.equal_to"] == 1
             return res

-    else:
+    elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
+        # Triton 3.0.0 - the original implementation
         from triton.compiler.compiler import AttrsDescriptor

         def AttrsDescriptorWrapper(
@@ -83,6 +85,20 @@ def AttrsDescriptorWrapper(
             # Instantiate AttrsDescriptor with the prepared arguments
             return AttrsDescriptor(**kwargs)

+    else:
+        # Triton in 2025:
+        # note: there's also a range of triton commits not currently supported
+        # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still
+        # used, but the contents are different.
+
+        def AttrsDescriptorWrapper(
+            divisible_by_16=None,
+            equal_to_1=None,
+        ):
+            return {
+                tuple((x,) for x in divisible_by_16): [["tt.divisibility", 16]],
+            }
+
 else:
     # Define a namedtuple as a fallback when AttrsDescriptor is not available
     AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
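
To make the newest branch concrete, here is a standalone copy of the 2025-style wrapper and the dict it produces, assuming flat integer indices as Inductor currently passes (note that `equal_to_1` is accepted but no longer encoded, matching the commit description above):

```python
def attrs_descriptor_wrapper(divisible_by_16=(), equal_to_1=()):
    # same construction as the new branch above: a single key holding every
    # divisible-by-16 argument index, each wrapped in its own tuple;
    # equal_to_1 is kept for API compatibility but not encoded here
    return {
        tuple((x,) for x in divisible_by_16): [["tt.divisibility", 16]],
    }


print(attrs_descriptor_wrapper(divisible_by_16=(0, 1, 3)))
# {((0,), (1,), (3,)): [['tt.divisibility', 16]]}
```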
