Skip to content

Commit 915f62d

Browse files
authored
Allow customization of the subscript operator for triton values (#7239)
This allows users to provide custom __setitem__ and __getitem__ functions in order to override the subscript operator for their triton classes. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent bfaab80 commit 915f62d

File tree

2 files changed

+112
-23
lines changed

2 files changed

+112
-23
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ def get_second(self, _semantic=None):
3232
def unpack(self):
3333
return self.get_first(), self.get_second()
3434

35+
def __getitem__(self, ind: tl.constexpr, _semantic=None):
36+
if ind == 0:
37+
return self.first
38+
assert ind == 1
39+
return self.second
40+
41+
def __setitem__(self, ind: tl.constexpr, value, _semantic=None):
42+
if ind == 0:
43+
self.first = value
44+
assert ind == 1
45+
self.second = value
46+
3547

3648
@filecheck_test
3749
@triton.jit
@@ -62,6 +74,47 @@ def test_augassign_attribute():
6274
anchor(pair)
6375

6476

77+
@filecheck_test
78+
@triton.jit
79+
def test_retrieve_item():
80+
# CHECK-LABEL: test_retrieve_item
81+
# CHECK: %c11_i32 = arith.constant 11 : i32
82+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
83+
scalar = 11
84+
pair = Pair(tl.arange(0, 4), scalar)
85+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}(%c11_i32)
86+
anchor(pair[1])
87+
88+
89+
@filecheck_test
90+
@triton.jit
91+
def test_assign_item():
92+
# CHECK-LABEL: test_assign_item
93+
# CHECK: %c11_i32 = arith.constant 11 : i32
94+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
95+
scalar = 11
96+
pair = Pair(tl.arange(0, 4), scalar)
97+
# CHECK: %c42_i32 = arith.constant 42 : i32
98+
pair[1] = 42
99+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[RANGE]], %c42_i32)
100+
anchor(pair)
101+
102+
103+
@filecheck_test
104+
@triton.jit
105+
def test_augassign_item():
106+
# CHECK-LABEL: test_augassign_item
107+
# CHECK: %c11_i32 = arith.constant 11 : i32
108+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
109+
scalar = 11
110+
pair = Pair(tl.arange(0, 4), scalar)
111+
# CHECK: %c42_i32 = arith.constant 42 : i32
112+
# CHECK: [[VALUE:%.*]] = arith.addi %c11_i32, %c42_i32
113+
pair[1] += 42
114+
# CHECK-NEXT: call @{{.*}}anchor{{.*}}([[RANGE]], [[VALUE]])
115+
anchor(pair)
116+
117+
65118
@filecheck_test
66119
@triton.jit
67120
def test_jit_method():
@@ -78,6 +131,32 @@ def test_jit_method():
78131
anchor(b)
79132

80133

134+
@tl.core._aggregate
135+
class TypeWithJitGetItem:
136+
value: tl.tensor
137+
138+
def __init__(self, value):
139+
self.value = value
140+
141+
@triton.jit
142+
def __getitem__(self, ind):
143+
return self.value
144+
145+
146+
@filecheck_test
147+
@triton.jit
148+
def test_jit_getitem():
149+
# CHECK-LABEL: test_jit_getitem
150+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
151+
v = TypeWithJitGetItem(tl.arange(0, 4))
152+
# CHECK: [[V:%.*]] = tt.call [[METHOD:@.*__getitem__.*]]([[RANGE]])
153+
a = v[0]
154+
# CHECK: call @{{.*}}anchor{{.*}}([[V]])
155+
anchor(a)
156+
# CHECK: tt.func private [[METHOD]]([[ARG0:%.*]]:
157+
# CHECK: tt.return [[ARG0]]
158+
159+
81160
@tl.core._aggregate
82161
class TypeWithBuiltinInitializer:
83162
value: tl.tensor

python/triton/compiler/code_generator.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,16 +1033,15 @@ def visit_Subscript_Load(self, node):
10331033
assert isinstance(node.ctx, ast.Load)
10341034
lhs = self.visit(node.value)
10351035
slices = self.visit(node.slice)
1036-
if _is_triton_tensor(lhs):
1037-
return lhs.__getitem__(slices, _semantic=self.semantic)
1036+
if _is_triton_value(lhs):
1037+
return self.call_Method(node, lhs.__getitem__, lhs, [slices], {})
10381038
return lhs[slices]
10391039

10401040
def visit_Subscript_Store(self, node, value):
10411041
assert isinstance(node.ctx, ast.Store)
10421042
lhs = self.visit(node.value)
10431043
slices = self.visit(node.slice)
1044-
assert isinstance(lhs, language.tuple)
1045-
lhs.__setitem__(slices, value)
1044+
self.call_Method(node, lhs.__setitem__, lhs, [slices, value], {})
10461045

10471046
def visit_Subscript(self, node):
10481047
return self.visit_Subscript_Load(node)
@@ -1238,32 +1237,18 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12381237
handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
12391238
return next(unflatten_ir_values(handles, [callee_ret_type]))
12401239

1241-
def visit_Call(self, node):
1242-
fn = _unwrap_if_constexpr(self.visit(node.func))
1243-
if not isinstance(fn, BoundJITMethod):
1244-
static_implementation = self.statically_implemented_functions.get(fn)
1245-
if static_implementation is not None:
1246-
return static_implementation(self, node)
1247-
1248-
mur = getattr(fn, '_must_use_result', False)
1249-
if mur and getattr(node, '_is_unused', False):
1250-
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
1251-
if isinstance(mur, str):
1252-
error_message.append(mur)
1253-
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
1254-
1255-
kws = dict(self.visit(keyword) for keyword in node.keywords)
1256-
args = [self.visit(arg) for arg in node.args]
1257-
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1240+
def call_Function(self, node, fn, args, kws):
12581241
if isinstance(fn, BoundJITMethod):
12591242
args.insert(0, fn.__self__)
12601243
fn = fn.__func__
12611244
if isinstance(fn, JITFunction):
12621245
_check_fn_args(node, fn, args)
12631246
return self.call_JitFunction(fn, args, kws)
12641247
if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
1265-
extra_kwargs = {"_semantic": self.semantic}
1248+
extra_kwargs = dict()
12661249
sig = inspect.signature(fn)
1250+
if '_semantic' in sig.parameters:
1251+
extra_kwargs["_semantic"] = self.semantic
12671252
if '_generator' in sig.parameters:
12681253
extra_kwargs['_generator'] = self
12691254
try:
@@ -1281,13 +1266,38 @@ def visit_Call(self, node):
12811266
# itself). But when calling a function, we raise as `from e` to
12821267
# preserve the traceback of the original error, which may e.g.
12831268
# be in core.py.
1284-
raise CompilationError(self.jit_fn.src, node, None) from e
1269+
raise CompilationError(self.jit_fn.src, node, str(e)) from e
12851270

12861271
if fn in self.builtin_namespace.values():
12871272
args = map(_unwrap_if_constexpr, args)
12881273
ret = fn(*args, **kws)
12891274
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
12901275

1276+
def call_Method(self, node, fn, fn_self, args, kws):
1277+
if isinstance(fn, JITFunction):
1278+
args.insert(0, fn_self)
1279+
return self.call_Function(node, fn, args, kws)
1280+
1281+
def visit_Call(self, node):
1282+
fn = _unwrap_if_constexpr(self.visit(node.func))
1283+
if not isinstance(fn, BoundJITMethod):
1284+
static_implementation = self.statically_implemented_functions.get(fn)
1285+
if static_implementation is not None:
1286+
return static_implementation(self, node)
1287+
1288+
mur = getattr(fn, '_must_use_result', False)
1289+
if mur and getattr(node, '_is_unused', False):
1290+
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
1291+
if isinstance(mur, str):
1292+
error_message.append(mur)
1293+
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
1294+
1295+
kws = dict(self.visit(keyword) for keyword in node.keywords)
1296+
args = [self.visit(arg) for arg in node.args]
1297+
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
1298+
1299+
return self.call_Function(node, fn, args, kws)
1300+
12911301
def visit_Constant(self, node):
12921302
return constexpr(node.value)
12931303

0 commit comments

Comments
 (0)