Skip to content

Commit 4c7a5f4

Browse files
authored
[Frontend] Add @tl.aggregate which autogenerates a Triton type (#6970)
📚 Stacked PRs 📚 * ➡️ triton-lang/triton#6970 * triton-lang/triton#6963 This PR adds a `@tl.aggregate` decorator which, when placed on a Python class with field annotations, automatically generates a Triton `base_type` and `base_value` based on the class. It wraps the type in a Triton type and moves all the methods over. This makes creating custom Triton types less verbose.
1 parent 3cad2e1 commit 4c7a5f4

File tree

3 files changed

+126
-38
lines changed

3 files changed

+126
-38
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
import io
44
import inspect
5-
from typing import List, Tuple
65

76
from filecheck.options import Options
87
from filecheck.finput import FInput
@@ -14,7 +13,6 @@
1413
from triton.compiler import ASTSource, make_backend
1514
from triton.backends.compiler import GPUTarget
1615
from triton._C.libtriton import ir
17-
from triton.language.core import base_type, base_value
1816

1917
import pytest
2018

@@ -113,38 +111,14 @@ def test_fn():
113111
# ===-----------------------------------------------------------------------===#
114112

115113

116-
class pair_type(base_type):
117-
118-
def __init__(self, first_type, second_type):
119-
self.first_type = first_type
120-
self.second_type = second_type
121-
122-
def __eq__(self, other) -> bool:
123-
return self.first_type == other.first_type and self.second_type == other.second_type
124-
125-
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
126-
first, cursor = self.first_type._unflatten_ir(handles, cursor)
127-
second, cursor = self.second_type._unflatten_ir(handles, cursor)
128-
return pair_value(first, second), cursor
129-
130-
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
131-
self.first_type._flatten_ir_types(builder, out)
132-
self.second_type._flatten_ir_types(builder, out)
133-
134-
def mangle(self) -> str:
135-
return f"pair<{self.first_type.mangle()}, {self.second_type.mangle()}>"
136-
137-
138-
class pair_value(base_value):
114+
@tl.core._aggregate
115+
class Pair:
116+
first: tl.tensor
117+
second: tl.tensor
139118

140119
def __init__(self, first, second):
141120
self.first = first
142121
self.second = second
143-
self.type = pair_type(first.type, second.type)
144-
145-
def _flatten_ir(self, handles: List[ir.value]) -> None:
146-
self.first._flatten_ir(handles)
147-
self.second._flatten_ir(handles)
148122

149123
@triton.jit
150124
def get_first(self):
@@ -158,19 +132,14 @@ def unpack(self):
158132
return self.get_first(), self.get_second()
159133

160134

161-
@tl.core.builtin
162-
def pair_value_ctor(first, second, _builder=None):
163-
return pair_value(first, second)
164-
165-
166135
@filecheck_test
167136
@triton.jit
168137
def test_assign_attribute():
169138
# CHECK-LABEL: assign_attribute
170139
# CHECK: %c11_i32 = arith.constant 11 : i32
171140
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
172141
scalar = 11
173-
pair = pair_value_ctor(tl.arange(0, 4), scalar)
142+
pair = Pair(tl.arange(0, 4), scalar)
174143
# CHECK: %c42_i32 = arith.constant 42 : i32
175144
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32)
176145
pair.second = 42
@@ -185,9 +154,34 @@ def test_jit_method():
185154
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
186155
scalar = 11
187156
# CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32)
188-
pair = pair_value_ctor(tl.arange(0, 4), scalar)
157+
pair = Pair(tl.arange(0, 4), scalar)
189158
a, b = pair.unpack()
190159
# CHECK: call @anchor{{.*}}([[V]]#0)
191160
anchor(a)
192161
# CHECK: call @anchor{{.*}}([[V]]#1)
193162
anchor(b)
163+
164+
165+
@tl.core._aggregate
166+
class TypeWithBuiltinInitializer:
167+
value: tl.tensor
168+
169+
def __init__(self, _builder=None):
170+
self.value = tl.arange(0, 4, _builder=_builder)
171+
172+
def modify(self, value, _builder=None):
173+
self.value = value
174+
175+
176+
@filecheck_test
177+
@triton.jit
178+
def test_aggregate_initializers():
179+
# CHECK-LABEL: test_aggregate_initializers
180+
value = TypeWithBuiltinInitializer()
181+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
182+
# CHECK: call @"anchor{{.*}}"([[RANGE]])
183+
anchor(value)
184+
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
185+
# CHECK: call @"anchor{{.*}}"([[RANGE]])
186+
value.modify(tl.arange(4, 8))
187+
anchor(value)

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def global_lookup(name: str, absent):
372372
type(val) is ModuleType, #
373373
isinstance(val, JITFunction), #
374374
getattr(val, "__triton_builtin__", False), #
375+
getattr(val, "__triton_aggregate__", False), #
375376
getattr(val, "__module__", "").startswith("triton.language"), #
376377
isinstance(val, language.dtype), #
377378
_is_namedtuple(val),

python/triton/language/core.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from functools import partial, wraps
77
import typing
88
from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
9+
from dataclasses import dataclass
910
import builtins
1011
from .. import knobs
11-
from ..runtime.jit import jit
12+
from ..runtime.jit import jit, JITFunction
1213
import inspect
1314

1415
from .._C.libtriton import ir
@@ -1487,6 +1488,98 @@ def _flatten_ir(self, handles: List[ir.value]) -> None:
14871488
handles.extend(s.handle for s in self.strides)
14881489

14891490

1491+
# -----------------------
1492+
# aggregate
1493+
# -----------------------
1494+
1495+
1496+
@dataclass(frozen=True)
1497+
class _aggregate_type(base_type):
1498+
"""A generic base type for all Triton aggregate types.
1499+
1500+
This class contains a reference to the original user-defined Python class
1501+
and a list of class fields with their Triton types.
1502+
"""
1503+
1504+
base_cls: type
1505+
fields: List[Tuple[str, base_type]]
1506+
1507+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
1508+
instance = self.base_cls._get_instance()
1509+
for name, ty in self.fields:
1510+
value, cursor = ty._unflatten_ir(handles, cursor)
1511+
setattr(instance, name, value)
1512+
return instance, cursor
1513+
1514+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
1515+
for name, ty in self.fields:
1516+
ty._flatten_ir_types(builder, out)
1517+
1518+
def mangle(self) -> str:
1519+
name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
1520+
fields = [ty.mangle() for (name, ty) in self.fields]
1521+
return f"{name}<{', '.join(fields)}>"
1522+
1523+
1524+
def _aggregate(cls):
1525+
1526+
# Define the wrapped Triton value type.
1527+
class aggregate_value(base_value):
1528+
__triton_builtin__ = True
1529+
__triton_aggregate__ = True
1530+
1531+
@classmethod
1532+
def _get_instance(this_cls):
1533+
return super().__new__(this_cls)
1534+
1535+
def __new__(this_cls, *args, _builder=None, _generator=None, **kwargs):
1536+
# Call into the user-defined constructor.
1537+
instance = this_cls._get_instance()
1538+
if isinstance(cls.__init__, JITFunction):
1539+
raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
1540+
extra_kwargs = {}
1541+
if "_builder" in inspect.signature(cls.__init__).parameters:
1542+
extra_kwargs["_builder"] = _builder
1543+
if "_generator" in inspect.signature(cls.__init__).parameters:
1544+
extra_kwargs["_generator"] = _generator
1545+
cls.__init__(instance, *args, **extra_kwargs, **kwargs)
1546+
1547+
# Require that the user-defined constructor initialized all fields.
1548+
for name in cls.__annotations__.keys():
1549+
if not hasattr(instance, name):
1550+
raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")
1551+
1552+
return instance
1553+
1554+
# Only allow setting attributes defined in the class annotations.
1555+
def __setattr__(self, name, value):
1556+
if name not in cls.__annotations__:
1557+
raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
1558+
if not isinstance(value, cls.__annotations__[name]):
1559+
raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
1560+
super().__setattr__(name, value)
1561+
1562+
def _flatten_ir(self, handles: List[ir.value]) -> None:
1563+
for name in cls.__annotations__.keys():
1564+
getattr(self, name)._flatten_ir(handles)
1565+
1566+
@property
1567+
def type(self):
1568+
return _aggregate_type(aggregate_value,
1569+
[(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
1570+
1571+
for (name, member) in inspect.getmembers(cls):
1572+
if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction):
1573+
if name != "__init__":
1574+
setattr(aggregate_value, name, member)
1575+
1576+
aggregate_value.__name__ = cls.__name__
1577+
aggregate_value.__module__ = cls.__module__
1578+
aggregate_value.__qualname__ = cls.__qualname__
1579+
1580+
return aggregate_value
1581+
1582+
14901583
# -----------------------
14911584
# SPMD Programming Model
14921585
# -----------------------

0 commit comments

Comments
 (0)