
Commit 9256bb3

davidberard98 authored and whitneywhtsang committed
[FRONTEND] Allow JITFunctions as arguments to other JITFunctions (#5723)
This PR allows a call to a JITFunction to pass another JITFunction as an argument. For example:

```python
@triton.jit
def fn_a(x):
    ...

@triton.jit
def fn_b(x, fn):
    ...

@triton.jit
def fn_c(x):
    return fn_b(x, fn_a)  # fn_a (a JITFunction) is passed as an argument to fn_b (another JITFunction)
```

Prior to #5220, this worked. After #5220, the user needs to annotate the JITFunction parameters with `tl.constexpr` manually (until this PR).

Use case: Inductor has some generic helper functions for implementing scans (e.g. `exclusive_scan_decoupled_lookback`) which take a `combine_fn` to implement the combination function (similar to `tl.reduce`). These helper functions have stopped working after #5220. https://github.com/pytorch/pytorch/blob/01a4d86b31365cfb484dc17885c9a7ee09c235ab/torch/_inductor/runtime/triton_helpers.py#L321
1 parent aa9630a commit 9256bb3

File tree

2 files changed

+30
-1
lines changed


python/test/unit/language/test_core.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -6940,6 +6940,35 @@ def inject_layout(ir, src: torch.Tensor, axis, indices: torch.Tensor, src_layout
     torch.testing.assert_close(output, ref, rtol=0, atol=0)
 
 
+@triton.jit
+def mul_jit_function(x, y):
+    return x * y
+
+
+@triton.jit
+def apply_binary_op(x, combine_op):
+    return combine_op(x, x)
+
+
+def test_jit_function_arg(device):
+
+    @triton.jit
+    def square_kernel_jit_function(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        in_data = tl.load(in_ptr + offsets)
+        out_data = apply_binary_op(in_data, mul_jit_function)  # pass a JITFunction into another JITFunction
+        tl.store(out_ptr + offsets, out_data)
+
+    BLOCK_SIZE = 16
+    x = torch.full((BLOCK_SIZE, ), 3.0, device=device)
+    out = torch.empty((BLOCK_SIZE, ), device=device)
+    expect = torch.full((BLOCK_SIZE, ), 9.0, dtype=x.dtype, device=device)
+
+    square_kernel_jit_function[(1, )](x, out, BLOCK_SIZE)
+
+    torch.testing.assert_close(out, expect)
+
+
 @pytest.mark.interpreter
 def test_zero_strided_tensors(device):
```

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1153,7 +1153,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
         args = inspect.getcallargs(fn.fn, *args, **kwargs)
         args = [args[name] for name in fn.arg_names]
         for i, arg in enumerate(args):
-            if isinstance(arg, (language.dtype, float, int, bool)):
+            if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
                 args[i] = language.core.constexpr(arg)
         args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
         args_cst = {path: get_iterable_path(args, path) for path in args_cst}
```
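The one-line fix above widens the `isinstance` check so that a `JITFunction` argument is wrapped as a compile-time constant, just like a dtype or a plain scalar. The following is a minimal, Triton-free sketch of that behavior; `Constexpr` and `JITFunction` here are simplified stand-ins introduced only for illustration, not the real Triton classes:

```python
class Constexpr:
    """Stand-in for triton.language.core.constexpr: wraps a compile-time constant."""

    def __init__(self, value):
        self.value = value


class JITFunction:
    """Stand-in for triton.runtime.jit.JITFunction: wraps a Python function."""

    def __init__(self, fn):
        self.fn = fn


def wrap_constexpr_args(args):
    # Mirrors the patched loop in call_JitFunction: before this PR the tuple
    # omitted JITFunction, so a JIT function passed as an argument was treated
    # as a runtime value; including it makes the callee see a constexpr.
    return [
        Constexpr(a) if isinstance(a, (float, int, bool, JITFunction)) else a
        for a in args
    ]


mul = JITFunction(lambda x, y: x * y)
wrapped = wrap_constexpr_args([3.0, mul, "runtime_tensor_placeholder"])
# wrapped[0] and wrapped[1] are Constexpr-wrapped; the placeholder is untouched.
```

In the real code path, the wrapped `constexpr` is then picked up by `find_paths_if(args, lambda _, x: _is_constexpr(x))` and specialized into the callee's signature rather than passed as a runtime argument.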
