|
6 | 6 | from functools import partial, wraps
|
7 | 7 | import typing
|
8 | 8 | from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
|
| 9 | +from dataclasses import dataclass |
9 | 10 | import builtins
|
10 | 11 | from .. import knobs
|
11 |
| -from ..runtime.jit import jit |
| 12 | +from ..runtime.jit import jit, JITFunction |
12 | 13 | import inspect
|
13 | 14 |
|
14 | 15 | from .._C.libtriton import ir
|
@@ -1487,6 +1488,98 @@ def _flatten_ir(self, handles: List[ir.value]) -> None:
|
1487 | 1488 | handles.extend(s.handle for s in self.strides)
|
1488 | 1489 |
|
1489 | 1490 |
|
| 1491 | +# ----------------------- |
| 1492 | +# aggregate |
| 1493 | +# ----------------------- |
| 1494 | + |
| 1495 | + |
| 1496 | +@dataclass(frozen=True) |
| 1497 | +class _aggregate_type(base_type): |
| 1498 | + """A generic base type for all Triton aggregate types. |
| 1499 | +
|
| 1500 | + This class contains a reference to the original user-defined Python class |
| 1501 | + and a list of class fields with their Triton types. |
| 1502 | + """ |
| 1503 | + |
| 1504 | + base_cls: type |
| 1505 | + fields: List[Tuple[str, base_type]] |
| 1506 | + |
| 1507 | + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]: |
| 1508 | + instance = self.base_cls._get_instance() |
| 1509 | + for name, ty in self.fields: |
| 1510 | + value, cursor = ty._unflatten_ir(handles, cursor) |
| 1511 | + setattr(instance, name, value) |
| 1512 | + return instance, cursor |
| 1513 | + |
| 1514 | + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: |
| 1515 | + for name, ty in self.fields: |
| 1516 | + ty._flatten_ir_types(builder, out) |
| 1517 | + |
| 1518 | + def mangle(self) -> str: |
| 1519 | + name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}" |
| 1520 | + fields = [ty.mangle() for (name, ty) in self.fields] |
| 1521 | + return f"{name}<{', '.join(fields)}>" |
| 1522 | + |
| 1523 | + |
| 1524 | +def _aggregate(cls): |
| 1525 | + |
| 1526 | + # Define the wrapped Triton value type. |
| 1527 | + class aggregate_value(base_value): |
| 1528 | + __triton_builtin__ = True |
| 1529 | + __triton_aggregate__ = True |
| 1530 | + |
| 1531 | + @classmethod |
| 1532 | + def _get_instance(this_cls): |
| 1533 | + return super().__new__(this_cls) |
| 1534 | + |
| 1535 | + def __new__(this_cls, *args, _builder=None, _generator=None, **kwargs): |
| 1536 | + # Call into the user-defined constructor. |
| 1537 | + instance = this_cls._get_instance() |
| 1538 | + if isinstance(cls.__init__, JITFunction): |
| 1539 | + raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function") |
| 1540 | + extra_kwargs = {} |
| 1541 | + if "_builder" in inspect.signature(cls.__init__).parameters: |
| 1542 | + extra_kwargs["_builder"] = _builder |
| 1543 | + if "_generator" in inspect.signature(cls.__init__).parameters: |
| 1544 | + extra_kwargs["_generator"] = _generator |
| 1545 | + cls.__init__(instance, *args, **extra_kwargs, **kwargs) |
| 1546 | + |
| 1547 | + # Require that the user-defined constructor initialized all fields. |
| 1548 | + for name in cls.__annotations__.keys(): |
| 1549 | + if not hasattr(instance, name): |
| 1550 | + raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'") |
| 1551 | + |
| 1552 | + return instance |
| 1553 | + |
| 1554 | + # Only allow setting attributes defined in the class annotations. |
| 1555 | + def __setattr__(self, name, value): |
| 1556 | + if name not in cls.__annotations__: |
| 1557 | + raise AttributeError(f"{cls.__name__} has no attribute '{name}'") |
| 1558 | + if not isinstance(value, cls.__annotations__[name]): |
| 1559 | + raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}") |
| 1560 | + super().__setattr__(name, value) |
| 1561 | + |
| 1562 | + def _flatten_ir(self, handles: List[ir.value]) -> None: |
| 1563 | + for name in cls.__annotations__.keys(): |
| 1564 | + getattr(self, name)._flatten_ir(handles) |
| 1565 | + |
| 1566 | + @property |
| 1567 | + def type(self): |
| 1568 | + return _aggregate_type(aggregate_value, |
| 1569 | + [(name, getattr(self, name).type) for name in cls.__annotations__.keys()]) |
| 1570 | + |
| 1571 | + for (name, member) in inspect.getmembers(cls): |
| 1572 | + if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction): |
| 1573 | + if name != "__init__": |
| 1574 | + setattr(aggregate_value, name, member) |
| 1575 | + |
| 1576 | + aggregate_value.__name__ = cls.__name__ |
| 1577 | + aggregate_value.__module__ = cls.__module__ |
| 1578 | + aggregate_value.__qualname__ = cls.__qualname__ |
| 1579 | + |
| 1580 | + return aggregate_value |
| 1581 | + |
| 1582 | + |
1490 | 1583 | # -----------------------
|
1491 | 1584 | # SPMD Programming Model
|
1492 | 1585 | # -----------------------
|
|
0 commit comments