Skip to content

Commit 37817d7

Browse files
authored
[FRONTEND] enable construction of named tuples inside triton functions (#5519)
1 parent dc261bf commit 37817d7

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

python/test/unit/language/test_tuple.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,23 @@ class Tensor(NamedTuple):
114114

115115

116116
@triton.jit
117-
def _namedtuple_kernel(closure, X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
117+
def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
118118
offs_m = tl.arange(0, BLOCK_M)
119119
offs_n = tl.arange(0, BLOCK_N)
120-
# load x
121-
mask_x = (offs_m[:, None] < X.shape[0]) & (offs_n[None, :] < X.shape[1])
120+
mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1])
121+
return mask
122+
123+
124+
@triton.jit
125+
def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
126+
offs_m = tl.arange(0, BLOCK_M)
127+
offs_n = tl.arange(0, BLOCK_N)
128+
X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
122129
Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
123-
x = tl.load(Xs, mask=mask_x, other=0)
124-
# compute y
125-
y = closure.fn(x, *closure.captured)
126-
# store y
127-
mask_y = (offs_m[:, None] < Y.shape[0]) & (offs_n[None, :] < Y.shape[1])
128130
Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
129-
tl.store(Ys, y, mask=mask_y)
131+
x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
132+
y = closure.fn(x, *closure.captured)
133+
tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))
130134

131135

132136
def test_namedtuple(device="cuda"):

python/triton/compiler/code_generator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,9 @@ def _is_constexpr_global(self, name):
315315

316316
return False
317317

318+
def _is_namedtuple(self, val):
319+
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
320+
318321
def _define_name_lookup(self):
319322

320323
def local_lookup(name: str, absent):
@@ -333,6 +336,7 @@ def global_lookup(name: str, absent):
333336
getattr(val, "__triton_builtin__", False), #
334337
getattr(val, "__module__", "").startswith("triton.language"), #
335338
isinstance(val, language.dtype), #
339+
self._is_namedtuple(val),
336340
self._is_constexpr_global(name), #
337341
# Allow accesses to globals while visiting an ast.arg
338342
# because you should be able to do
@@ -535,6 +539,11 @@ def assignTarget(self, target, value):
535539
def visit_Assign(self, node):
536540
# construct values to assign
537541
def _sanitize_value(value):
542+
if self._is_namedtuple(type(value)):
543+
vals = [_sanitize_value(v) for v in value]
544+
types = [v.type for v in vals]
545+
fields = type(value)._fields
546+
return language.tuple(vals, language.tuple_type(types, fields))
538547
if isinstance(value, language.tuple):
539548
return language.tuple([_sanitize_value(v) for v in value.values])
540549
native_nontensor_types = (language.dtype, language.tuple)

0 commit comments

Comments
 (0)