Skip to content

Commit d25fc5f

Browse files
authored
[Frontend] Support binding self to JITFunction when they are methods (#6963)
📚 Stacked PRs 📚 * triton-lang/triton#6970 * ➡️ triton-lang/triton#6963 This PR makes `base_value.method` return a BoundJITMethod which keeps `base_value` to be passed as `__self__`.
1 parent 90cdc01 commit d25fc5f

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,7 @@ def run_filecheck(name, module_str, check_template):
4141
raise ValueError(matcher.stderr.getvalue())
4242

4343

44-
def run_filecheck_test(kernel_fn):
45-
assert isinstance(kernel_fn, triton.runtime.JITFunction)
46-
check_template = inspect.getsource(kernel_fn.fn)
47-
if check_template is None:
48-
raise ValueError("kernel function must have a docstring with FileCheck template")
44+
def run_parser(kernel_fn):
4945
sigkeys = [x.name for x in kernel_fn.params]
5046
sigvals = [f"arg{i}" for i in range(len(sigkeys))]
5147
signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
@@ -59,7 +55,15 @@ def run_filecheck_test(kernel_fn):
5955
options = stub_backend.parse_options(dict(**extra_options))
6056
codegen_fns = stub_backend.get_codegen_implementation(options)
6157
module_map = stub_backend.get_module_map()
62-
mlir_module = src.make_ir(options, codegen_fns, module_map, context)
58+
return src.make_ir(options, codegen_fns, module_map, context)
59+
60+
61+
def run_filecheck_test(kernel_fn):
62+
assert isinstance(kernel_fn, triton.runtime.JITFunction)
63+
check_template = inspect.getsource(kernel_fn.fn)
64+
if check_template is None:
65+
raise ValueError("kernel function must have a docstring with FileCheck template")
66+
mlir_module = run_parser(kernel_fn)
6367

6468
run_filecheck("placeholder", str(mlir_module), check_template)
6569

@@ -142,6 +146,17 @@ def _flatten_ir(self, handles: List[ir.value]) -> None:
142146
self.first._flatten_ir(handles)
143147
self.second._flatten_ir(handles)
144148

149+
@triton.jit
150+
def get_first(self):
151+
return self.first
152+
153+
def get_second(self, _builder=None):
154+
return self.second
155+
156+
@triton.jit
157+
def unpack(self):
158+
return self.get_first(), self.get_second()
159+
145160

146161
@tl.core.builtin
147162
def pair_value_ctor(first, second, _builder=None):
@@ -160,3 +175,19 @@ def test_assign_attribute():
160175
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32)
161176
pair.second = 42
162177
anchor(pair)
178+
179+
180+
@filecheck_test
181+
@triton.jit
182+
def test_jit_method():
183+
# CHECK-LABEL: test_jit_method
184+
# CHECK: %c11_i32 = arith.constant 11 : i32
185+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
186+
scalar = 11
187+
# CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32)
188+
pair = pair_value_ctor(tl.arange(0, 4), scalar)
189+
a, b = pair.unpack()
190+
# CHECK: call @anchor{{.*}}([[V]]#0)
191+
anchor(a)
192+
# CHECK: call @anchor{{.*}}([[V]]#1)
193+
anchor(b)

python/triton/compiler/code_generator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import warnings
55
import textwrap
66
import itertools
7+
from dataclasses import dataclass
78
from types import ModuleType
89
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
910

@@ -271,6 +272,12 @@ def make_template(ty):
271272
return vals
272273

273274

275+
@dataclass(frozen=True)
276+
class BoundJITMethod:
277+
__self__: base_value
278+
__func__: JITFunction
279+
280+
274281
class CodeGenerator(ast.NodeVisitor):
275282

276283
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map,
@@ -1247,6 +1254,9 @@ def visit_Call(self, node):
12471254
kws = dict(self.visit(keyword) for keyword in node.keywords)
12481255
args = [self.visit(arg) for arg in node.args]
12491256
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1257+
if isinstance(fn, BoundJITMethod):
1258+
args.insert(0, fn.__self__)
1259+
fn = fn.__func__
12501260
if isinstance(fn, JITFunction):
12511261
_check_fn_args(node, fn, args)
12521262
return self.call_JitFunction(fn, args, kws)
@@ -1338,7 +1348,10 @@ def visit_Attribute(self, node):
13381348
lhs = self.visit(node.value)
13391349
if _is_triton_tensor(lhs) and node.attr == "T":
13401350
return semantic.permute(lhs, (1, 0), builder=self.builder)
1341-
return getattr(lhs, node.attr)
1351+
attr = getattr(lhs, node.attr)
1352+
if _is_triton_value(lhs) and isinstance(attr, JITFunction):
1353+
return BoundJITMethod(lhs, attr)
1354+
return attr
13421355

13431356
def visit_Expr(self, node):
13441357
node.value._is_unused = True

0 commit comments

Comments
 (0)