Skip to content

Commit 9de9d25

Browse files
blaine-rister authored and pytorchmergebot committed
[Inductor-FX] Support custom triton kernels (pytorch#161474)
# Feature Add support for custom Triton kernels to the FX backend. This turned out not to require any new features, except for a minor change to handle `tl.constexpr` arguments which are not part of the autotuning config. # Caveat This may not cover every possible case. For example, we might need more features for autotuning custom Triton code. This PR entirely skips the [custom codegen ](https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/triton_kernel_wrap.py#L1034-L1039) for user-defined grid functions, but there may be edge cases requiring this logic. However, this PR seems to do a reasonable job as many of the grids end up being written into Inductor/Triton metadata and don't require special codegen. As a follow up, I'm planning to test this against all of AOTI's custom Triton kernel tests. # Test plan Added a CI test using a custom Triton kernel. Pull Request resolved: pytorch#161474 Approved by: https://github.com/angelayi
1 parent dbc903a commit 9de9d25

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

test/inductor/test_fxir_backend.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
)
3636

3737

38+
if HAS_GPU:
39+
import triton
40+
import triton.language as tl
41+
42+
3843
@requires_gpu()
3944
@config.patch(
4045
compile_threads=1,
@@ -544,6 +549,37 @@ def run(*args, **kwargs):
544549
if use_dynamic_shapes:
545550
self.assertEqual(type(shape[0]), torch.fx.Node)
546551

552+
def test_custom_triton(self):
    """Compile a function that launches a user-defined Triton kernel
    through the FX backend and check it against eager execution."""

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance handles one contiguous block of elements.
        block_id = tl.program_id(axis=0)
        idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
        rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
        tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)

    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = torch.empty_like(x)
        n_elements = output.numel()

        # 1D launch grid: enough blocks to cover every element.
        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
        return output

    args = [torch.randn(32, device=self.device) for _ in range(2)]
    self._compile_and_check(add, args)
582+
547583
def test_output_slice_view(self):
548584
"""
549585
Test when the output is a view of the input.

torch/_inductor/codegen/wrapper_fxir.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis
3434

3535
from .. import config, ir
36+
from ..runtime.triton_compat import Config
3637
from ..utils import LineContext
3738
from .common import (
3839
CodegenSymbol,
@@ -700,9 +701,40 @@ def node_to_tuning_arg(arg: Any) -> Any:
700701
kernel_name,
701702
)
702703

704+
triton_meta = tuner.triton_meta
705+
signature = triton_meta["signature"]
706+
707+
def add_constants_to_call_args(
    call_args: Sequence[Any], cfg: Config
) -> tuple[Any, ...]:
    """
    Merge `tl.constexpr` constants from the Triton metadata into the
    positional call args, following the kernel's proper signature order.
    Autotuner config kwargs are excluded, as they are tracked separately.
    """
    constants = triton_meta["constants"]
    merged: list[Any] = []
    consumed = 0
    for name in signature:
        # Tuning knobs come from the config, not the arg list.
        if name in cfg.kwargs:
            continue

        if name in constants:
            merged.append(constants[name])
        else:
            merged.append(call_args[consumed])
            consumed += 1

    # Inductor may append extra trailing call args; keep them at the end.
    merged.extend(call_args[consumed:])

    return tuple(merged)
733+
703734
kernel_config = tuner.compile_results[0].config
735+
call_args = add_constants_to_call_args(call_args, kernel_config)
704736
call_args, grid = tuner._interpret_args_grid(call_args, kernel_config)
705-
call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args))
737+
call_kwargs = dict(zip(signature, call_args))
706738
call_kwargs.update(kernel_config.kwargs)
707739

708740
wrapper_grid = [tuple(self._generate_sym_nodes(grid))]

0 commit comments

Comments
 (0)