diff --git a/AGENTS.md b/AGENTS.md index c0921992..aee36af4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -37,27 +37,4 @@ ## Polyhedral DSL Guidelines - Prefer using Mixin operator overloads (e.g., `A | B` instead of `A.union(B)`) for cleaner code in user scripts and DSL implementations. - -## 作業計画と進捗 (2025-11-16) -直近のギャップ集計: `docs/ISL_missing_apis.md`(2025-11-16 再生成、欠落API 2047件)。map 残 2 件(tuple_name系シンボル未提供のみ、libisl非存在)。 -優先順とステータス(✅完了 / 🚧着手中 / ⏳未着手) -- Identifier / Id: 🚧(基本APIは揃うが欠落検証 継続) -- Space / LocalSpace: 🚧(dim/tuple系以外の抜け有り) -- Constraint / Equality-Constraint / Inequality-Constraint: 🚧 -- BasicSet / Set: 🚧(missing計: set 105, basic_set 63) -- UnionSet: 🚧(missing計: union_set 52) -- BasicMap / Map: 🚧(missing計: basic_map 85, map 190) -- UnionMap: 🚧(missing計: union_map 112) -- Aff / PwAff / MultiAff / PwMultiAff: 🚧(missing計: aff 73, pw_aff 96, multi_aff 90, pw_multi_aff 89) -- MultiVal: 🚧(missing計: multi_val 37, val 66) -- MultiUnionPwAff / UnionPwAff / UnionPwMultiAff / MultiUnionPwAff: 🚧(missing計: multi_union_pw_aff 75 ほか) -- ScheduleConstraint / Schedule / ScheduleNode: ✅(schedule_node 0) -- UnionAccessInfo / UnionFlow: ⏳ -- ASTExpr / ASTNode / ASTBuild: 🚧(Expr系クラス不足・missing計: ast_expr 0, ast_node 0) -- Mat: ✅(要素参照系API実装済・missing計: mat 0) -- その他: misc 71, options 29 など多数。 - -次に着手する対象: -1) ScheduleNode / ASTExpr / Mat のクラス追加・アクセサ補完 -2) UnionAccessInfo / UnionFlow ラッパ実装 -3) map / set 系を皮切りに `docs/ISL_missing_apis.md` に基づく欠落API埋め +- Write clean, maintainable code and be respectful of the existing codebase. 
diff --git a/caten/__init__.py b/caten/__init__.py index e69de29b..c61529e3 100644 --- a/caten/__init__.py +++ b/caten/__init__.py @@ -0,0 +1,4 @@ +from .dtype import * # noqa: F403, I001 +from .simplifier import * # noqa: F403, I001 +from .tensor import * # noqa: F403, I001 +from .runtime import cpu # noqa: I001, F401 diff --git a/caten/dtype.py b/caten/dtype.py new file mode 100644 index 00000000..9072b046 --- /dev/null +++ b/caten/dtype.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +class DTypeMetaClass(type): + dcache: dict[tuple, DType] = {} + def __call__(cls, *args: Any, **kwargs: Any) -> DType: + if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret + DTypeMetaClass.dcache[args] = ret = super().__call__(*args) + return ret + +# TODO: Vector/Packed DType +@dataclass(frozen=True, eq=False) +class DType: + name: str + @staticmethod + def new(name:str) -> DType: return DType(name) + +## definitions +float64 = DType.new("float64") +float32 = DType.new("float32") +float16 = DType.new("float16") + +int64 = DType.new("int64") +int32 = DType.new("int32") +int16 = DType.new("int16") +int8 = DType.new("int8") +uint64 = DType.new("uint64") +uint32 = DType.new("uint32") +uint16 = DType.new("uint16") +uint8 = DType.new("uint8") + +## dtype aliases +index = int64 +default_float = float32 + +floats = [float64, float32, float16] +integers = [int64, int32, int16, int8, uint64, uint32, uint16, uint8] diff --git a/caten/helpers.py b/caten/helpers.py new file mode 100644 index 00000000..5cecc491 --- /dev/null +++ b/caten/helpers.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import functools +import operator +from typing import Any, Iterable, TypeVar + +T = TypeVar("T") + +def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1) + +def argfix(*x: Any) -> tuple[Any, ...]: + if x and x[0].__class__ in (tuple, list): + if len(x) != 1: raise ValueError(f"bad 
arg {x}") + return tuple(x[0]) + return x + +def align_left(*shapes: tuple[Any, ...]) -> tuple[tuple[Any, ...], ...]: + # unsqueeze left to make every shape same length + max_dim = max(len(shape) for shape in shapes) + return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes) diff --git a/caten/ir.py b/caten/ir.py new file mode 100644 index 00000000..e0d1dfbd --- /dev/null +++ b/caten/ir.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +import dataclasses +import itertools +import math +import operator +import weakref +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +from .dtype import DType, index + + +class ATenOpMetaclass(type): + cache: Dict[tuple, weakref.ReferenceType[ATenOp]] = {} + @staticmethod + def _freeze(x: Any) -> Any: + if isinstance(x, ATenOp): return x + if dataclasses.is_dataclass(x): + return (type(x),) + tuple((f.name, ATenOpMetaclass._freeze(getattr(x, f.name))) for f in dataclasses.fields(x)) + if isinstance(x, (list, tuple)): + return tuple(ATenOpMetaclass._freeze(i) for i in x) + if isinstance(x, dict): + return tuple(sorted((k, ATenOpMetaclass._freeze(v)) for k, v in x.items())) + return x + def __call__(cls, args: tuple[ATenOp, ...] | list[ATenOp], T: "ATenOpType | None" = None, **kwargs: Any) -> ATenOp: + T = cls.verify(tuple(args), T, **kwargs) # type: ignore + wret = ATenOpMetaclass.cache.get(key:=(cls, tuple(args), ATenOpMetaclass._freeze(T), ATenOpMetaclass._freeze(kwargs)), None) + if wret is not None and (ret:=wret()) is not None: return ret.simplify() + ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(tuple(args), T=T, **kwargs)) + return created.simplify() + +@dataclass(frozen=True) +class ATenAxis(): + size: ATenOp + stride: ATenOp + offset: ATenOp + incf: ATenOp + def index(self, i: ATenOp) -> ATenOp: + assert i.T is not None and i.T.dtype == index, "ATenAxis.index: range index should be type of index." 
+ return Mul((self.stride, Add((Mul((i, self.incf)), self.offset)))) + +def _const(val: int, dtype: DType=index) -> ATenOp: + if isinstance(val, Const): return val + else: return Const.new(val, dtype) + +@dataclass(frozen=True) +class ATenOpType(): + axes: tuple[ATenAxis, ...] + dtype: DType + offset: Union[ATenOp, None] = None + is_ptr: bool = False # TODO: for vectorize? + def index(self, indices: List[ATenOp]) -> Any: + assert self.ndim == len(indices) + total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes, strict=True)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) # type: ignore + if self.offset: total = Add((total, self.offset)) # type: ignore + return total + @property + def ndim(self) -> int: return len(self.axes) + @staticmethod + def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: + def _mul(a: Any, b: Any) -> Any: return Mul((_const(a), _const(b))) + strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] + return ATenOpType( + axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides, strict=True)]), # type: ignore + dtype=dtype, + ) + +@dataclass(frozen=True) +class ATenOp(metaclass=ATenOpMetaclass): + args: tuple[ATenOp, ...] + T: Union[ATenOpType, None] = None # this should be provided via T=... option, or inferred via verify method. 
+ @property + def predecessors(self) -> tuple[ATenOp, ...]: + return tuple(self.args) + (tuple(*[tuple((axis.size, axis.stride, axis.offset, axis.incf)) for axis in self.T.axes]) + () if self.T is not None else ()) + ((self.offset,) if self.T and self.T.offset is not None else ()) # type: ignore + + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: + raise NotImplementedError("Not implemented") + + def simplify(self) -> ATenOp: + from caten.simplifier import simplifier + return simplifier.simplify(self) + + def deepwalk(self) -> None: + pass + + def viz(self) -> None: + pass + + @property + def item(self) -> Union[int, float, ATenOp]: + # Returns scalar value if self is constant folded + if isinstance(self, Const) and isinstance(self.value, (int, float)): + return self.value + else: return self + # Mixin for computing shapes (required by reshape, etc) + # TODO: Use same semantic of broadcast as tensor + def __add__(self, other: Any) -> ATenOp: return Add((self, _const(other))) + def __radd__(self, other: Any) -> ATenOp: return Add((_const(other), self)) + def __mul__(self, other: Any) -> ATenOp: return Mul((self, _const(other))) + def __rmul__(self, other: Any) -> ATenOp: return Mul((_const(other), self)) + # note: do not try to overload __eq__ since it is need to compute hash + @staticmethod + def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]) -> bool: + """ + Compare two scalars (Python numbers or ATenOp scalars) for equality. + """ + if isinstance(a, (int, float)) and isinstance(b, (int, float)): return (a == b) + assert isinstance(a, ATenOp) and a.T is not None + dtype = a.T.dtype if isinstance(a, ATenOp) else b.T.dtype # A or B is asserted to have a dtype + a, b = _const(a, dtype=dtype), _const(b, dtype=dtype) # type: ignore + # Note(hikettei): this comparison highly depends on whether they are constant folded. 
+ # plus, cannot verify the equivalence of A*B and B*A + return a == b + @staticmethod + def equals(a: tuple[Union[int, float, ATenOp], ...], b: tuple[Union[int, float, ATenOp], ...]) -> bool: + """ + Compare two lists element-wise using `ATenOp.eql` + """ + if not len(a) == len(b): return False + for ai, bi in zip(a, b, strict=True): + if not ATenOp.eql(ai, bi): return False + return True +## == Tensor Graph ============================================================ +class UnaryOps(): + # ops whose first argument is returned dtype + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: + assert len(args) == 1, f"UnaryOp {cls.__name__} takes one argument, getting {args}" + assert args[0].T is not None + return args[0].T +class BinaryOps(): + # ops whose first argument is returned dtype + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: + assert len(args) == 2, f"BinaryOp {cls.__name__} takes two argument, getting {args}" + assert args[0].T is not None + return args[0].T +class TernaryOps(): + # ops whose first argument is returned dtype + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: + assert len(args) == 3, f"TernaryOp {cls.__name__} takes three argument, getting {args}" + assert args[0].T is not None + return args[0].T +class ViewOps(): + # ops whose return dtypes are explicitly provided via T option + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: + assert T is not None, f"Cannot create {cls.__name__} without providing T" + return T +### UnaryOps +@dataclass(frozen=True) +class Neg(UnaryOps, ATenOp): + """ + OUT = -X + """ + python_op = lambda x: -x + +@dataclass(frozen=True) +class Recip(UnaryOps, ATenOp): + """ + OUT = 1/X + """ + python_op = lambda x: 1/x + +@dataclass(frozen=True) +class 
Sin(UnaryOps, ATenOp): + """ + OUT = sin(X) + """ + python_op = math.sin + +@dataclass(frozen=True) +class Exp2(UnaryOps, ATenOp): + """ + OUT = exp2(X) + """ + python_op = math.exp2 + +@dataclass(frozen=True) +class Log2(UnaryOps, ATenOp): + """ + OUT = log2(X) + """ + python_op = math.log2 + +@dataclass(frozen=True) +class Sqrt(UnaryOps, ATenOp): + """ + OUT = sqrt(X) + """ + python_op = math.sqrt + +@dataclass(frozen=True) +class Bitcast(UnaryOps, ATenOp): + pass + +@dataclass(frozen=True) +class Not(UnaryOps, ATenOp): + """ + Logical not if the X is a boolean + otherwise lognot ~x + """ +### BinaryOps +@dataclass(frozen=True) +class Add(BinaryOps, ATenOp): + """ + OUT = Add(X, Y) + """ + python_op = operator.add + +@dataclass(frozen=True) +class Mul(BinaryOps, ATenOp): + """ + OUT = Mul(X, Y) + """ + python_op = operator.mul + +@dataclass(frozen=True) +class IDiv(BinaryOps, ATenOp): + """ + OUT = A // B + """ + python_op = operator.floordiv + +@dataclass(frozen=True) +class And(BinaryOps, ATenOp): + pass + +@dataclass(frozen=True) +class Or(BinaryOps, ATenOp): + pass + +@dataclass(frozen=True) +class Xor(BinaryOps, ATenOp): + pass + +@dataclass(frozen=True) +class Max(BinaryOps, ATenOp): + python_op = max + +@dataclass(frozen=True) +class Mod(BinaryOps, ATenOp): + python_op = operator.mod + +@dataclass(frozen=True) +class Neq(BinaryOps, ATenOp): + python_op = operator.ne + +@dataclass(frozen=True) +class Lt(BinaryOps, ATenOp): + python_op = operator.lt +### TernaryOps +@dataclass(frozen=True) +class Where(TernaryOps, ATenOp): + python_op = lambda a, b, c: b if a else c + +### Allocation +@dataclass(frozen=True) +class Const(ViewOps, ATenOp): + value: Union[int, float, str, bool] = 0.0 + @staticmethod + def new(value: Union[int, float, str, bool], dtype: DType) -> Const: + return Const(args=(), value=value, T=ATenOpType(axes=(), dtype=dtype)) + +@dataclass(frozen=True) +class Allocate(ViewOps, ATenOp): + """ + Allocate(S1, S2, S3, ...) 
+ """ + @staticmethod + def new(shape: List[Any], dtype: DType) -> Allocate: + return Allocate((), T=ATenOpType.from_shape(shape, dtype)) + +@dataclass(frozen=True) +class View(ViewOps, ATenOp): + """ + View(X, T=T_New) + """ + # This is the definition of view + @staticmethod + def reshape(tensor: ATenOp, shape: List[ATenOp]) -> View: + assert tensor.T is not None + return View((tensor,), T=ATenOpType.from_shape(shape, tensor.T.dtype)) + + @staticmethod + def permute(tensor: ATenOp, order: List[int]) -> View: + assert tensor.T is not None + return View((tensor,), T=ATenOpType( + axes=tuple([tensor.T.axes[i] for i in order]), + dtype=tensor.T.dtype, + offset=tensor.T.offset, + is_ptr=tensor.T.is_ptr + )) + + @staticmethod + def expand(tensor: ATenOp, shape: List[Union[int, ATenOp]]) -> View: + assert tensor.T is not None + def _expand(old_axis: ATenAxis, new_size: int | float | ATenOp) -> ATenAxis: + if ATenOp.eql(old_axis.size, new_size): return old_axis + else: + assert ATenOp.eql(old_axis.size, 1), f"The axis to expand should be evaluated to 1, getting {old_axis}" # Fix: old_axis -> old_axis.size + return ATenAxis(size=_const(new_size), stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) # type: ignore + return View((tensor,), T=ATenOpType( + axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape, strict=True)]), + dtype=tensor.T.dtype, + offset=tensor.T.offset, + is_ptr=tensor.T.is_ptr + )) +## == JIT ===================================================================== +@dataclass(frozen=True) +class Reduce(ATenOp): + """ + OUT = Reduce(A, B, op=BinaryOps) + """ + op: type[BinaryOps] = Add + @classmethod + def from_ast_expr(cls) -> None: + pass + +@dataclass(frozen=True) +class Store(ATenOp): + pass +## ControlFlow +@dataclass(frozen=True) +class Range(ATenOp): + pass + +@dataclass(frozen=True) +class Loop(ATenOp): + pass + +@dataclass(frozen=True) +class When(ATenOp): + pass + 
+@dataclass(frozen=True) +class Progn(ATenOp): + pass +## == ScheduleOps ============================================================ +@dataclass(frozen=True) +class Polyhedral(ATenOp): + pass + +def Var() -> None: + pass + +# e.g.: +# a = T.Var("A[m n]", float32) +# P.stmt("...")[a] diff --git a/caten/ops.py b/caten/ops.py deleted file mode 100644 index 0e48c0bb..00000000 --- a/caten/ops.py +++ /dev/null @@ -1,7 +0,0 @@ - -class TOp: - pass - -# UOp.ADD, UOp.MUL, UOp.exp -# Pattern Matcher -# Shape diff --git a/caten/runtime/cpu.py b/caten/runtime/cpu.py new file mode 100644 index 00000000..1b544b15 --- /dev/null +++ b/caten/runtime/cpu.py @@ -0,0 +1,22 @@ +from typing import Any + +import caten as C + + +class CPUTensor(C.ATenBase): + def allocate(self) -> None: + pass + + def free(self) -> None: + pass + + #@staticmethod + def compile(self) -> None: + pass + + @staticmethod + def render(op: Any) -> None: + def _render(node: Any) -> None: + pass + +C.ATenBase.register("CPU", CPUTensor) diff --git a/caten/simplifier.py b/caten/simplifier.py new file mode 100644 index 00000000..9b0151a0 --- /dev/null +++ b/caten/simplifier.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import inspect +from dataclasses import is_dataclass, replace +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import caten.ir as ir + +from .ir import ATenOp + +OpType = Type[ATenOp] + +class Pat: + def __init__( + self, + op: Union[OpType, Tuple[OpType, ...], None] = None, + name: Optional[str] = None, + src: Optional[Tuple["Pat", ...]] = None, + meta: Optional[Dict[str, Callable[[ATenOp], Any]]] = None, + ): + self.op = op if (op is None or isinstance(op, tuple)) else (op,) + self.name, self.src = name, src + self.meta = meta or {} + def match(self, n: ATenOp, ctx: Dict[str, Any]) -> bool: + if self.op is not None and not isinstance(n, self.op): return False + if self.src is not None: + ks = tuple(n.args) + if len(ks) != len(self.src): return False + if 
not all(p.match(k, ctx) for p, k in zip(self.src, ks, strict=True)): return False + for var, ex in self.meta.items(): + v = ex(n) + if var in ctx and ctx[var] != v: return False + ctx[var] = v + if self.name: + if self.name in ctx and ctx[self.name] != n: return False + ctx[self.name] = n + return True + @staticmethod + def var(name: str) -> "Pat": + return Pat(name=name) + +class Simplifier: + def __init__(self, patterns: List[Tuple[Pat, Callable[..., Any]]]): + self.patterns = patterns + def __add__(self, other: Simplifier) -> Simplifier: + return Simplifier(self.patterns+other.patterns) + def rewrite(self, n: ATenOp, ctx_obj: Any = None) -> Optional[ATenOp]: + for pat, fn in self.patterns: + m: Dict[str, Any] = {} + if not pat.match(n, m): continue + + sig = inspect.signature(fn) + argv = [(ctx_obj if p == "ctx" else m.get(p)) for p in sig.parameters] + out = fn(*argv) + + if out is None: continue + if isinstance(out, ATenOp): return out + raise TypeError(f"rewrite returned unsupported type: {type(out)}") + return None + + def _walk_once(self, root: ATenOp, ctx_obj: Any = None) -> Tuple[ATenOp, bool]: + changed = False + memo: Dict[int, ATenOp] = {} + def go(n: ATenOp) -> ATenOp: + nonlocal changed + if (r := memo.get(id(n))) is not None: return r + ks = tuple(n.args) + if ks: + ks2 = tuple(go(k) for k in ks) + if ks2 != ks and is_dataclass(n): + n = replace(n, args=type(n.args)(ks2)) + changed = True + while True: + r2 = self.rewrite(n, ctx_obj) + if r2 is None or r2 == n: break + n = r2 + changed = True + memo[id(n)] = n + return n + return go(root), changed + + def simplify(self, root: ATenOp, ctx_obj: Any = None, max_iters: int = 30000) -> ATenOp: + cur = root + for _ in range(max_iters): + cur, ch = self._walk_once(cur, ctx_obj) + if not ch: break + return cur + +# Guard Methods +#def Guard(obj): pass +def _is_num(x: Any) -> bool: + return isinstance(x, (int, float)) and not isinstance(x, bool) + +constant_folder = Simplifier( + # UnaryOps + [( + Pat(op, 
src=(Pat(ir.Const, name="x"),)), # type: ignore + (lambda op: (lambda x: ir.Const.new(op.python_op(x.value), x.T.dtype) # type: ignore + if _is_num(x.value) else None))(op), + ) for op in [ir.Neg, ir.Recip, ir.Sin, ir.Exp2, ir.Log2, ir.Sqrt]] + + + # BinaryOps + [( + Pat(op, src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), # type: ignore + (lambda op: (lambda a, b: ir.Const.new(op.python_op(a.value, b.value), a.T.dtype) # type: ignore + if (a.T.dtype == b.T.dtype and _is_num(a.value) and _is_num(b.value)) else None))(op), + ) for op in [ir.Add, ir.Mul, ir.IDiv, ir.Max, ir.Mod, ir.Neq, ir.Lt]] +) + + +simplifier = constant_folder diff --git a/caten/tensor.py b/caten/tensor.py new file mode 100644 index 00000000..3ea242af --- /dev/null +++ b/caten/tensor.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import os +from abc import ABCMeta, abstractmethod +from typing import Any, Callable, List, Self, Tuple, Union + +import caten.ir as ir +from caten.helpers import align_left, argfix, prod + +from .dtype import DType, default_float, floats, index, integers + +## Backend Abstraction +DEVICE_TO_TENSOR = {} +def get_backend() -> str: return os.environ.get("BACKEND", "CPU") +## Tensor annotation for jit/aot shape check +class ATenSpec: + """ + C.Tensor[M, N] -> ATenSpec(M N) + """ + def __init__(self, shape: Tuple[Any, ...]): + self.shape: tuple[Union[int, str], ...] = shape + def __repr__(self) -> str: return f"ATenSpec{self.shape}" +## Tensor compiler core +class ATen: + op: ir.ATenOp # ATen is just a wrapper for ATenOp + @classmethod + def from_shape(cls, shape: List[ir.ATenOp], dtype: DType=default_float) -> Tensor: return Tensor(op=ir.Allocate.new(shape, dtype)) # type: ignore + @classmethod + def const(cls, obj: Any, dtype: DType=index) -> ir.Const: + match obj: + case int(): assert dtype in integers + case float(): assert dtype in floats + case _: raise TypeError(f"ATen.const: Only integer or float objects can become constant! 
getting {obj}") + return ir.Const.new(obj, dtype) + def forward(self, op: Callable, *args: List, **kwargs) -> Tensor: return Tensor(op=op(*args, **kwargs)) # type: ignore + def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> ATenSpec: return ATenSpec(item) + def __repr__(self) -> str: + # TODO: Display Shape, realized buffer, etc. + return f"{self.__class__.__name__}<{self.op}>" + @property + def dtype(self) -> DType: + assert self.op.T is not None + return self.op.T.dtype + @staticmethod + def wrap_const(obj: Union[ATen, ir.ATenOp, float, int], dtype: DType = index) -> ir.ATenOp: + """ + Ensures obj is a constant of dtype + """ + if isinstance(obj, ATen): + assert obj.dtype == dtype # todo: decent error msg + return obj.op + elif isinstance(obj, ir.ATenOp): + assert obj.T is not None and obj.T.dtype == dtype # todo: decent error msg + return obj + else: + return ATen.const(obj, dtype=dtype) + @staticmethod + def top(f: Callable) -> Callable: + """ + Declares the given function as toplevel tensor operation. 
+ """ + # TODO: Toplevel in helpers.py + return f +## movement ops mixin +class ATenMovements(): + @property + def shape(self) -> tuple[ir.ATenOp, ...]: return tuple([x.size for x in self.op.T.axes]) # type: ignore + @property + def strides(self) -> tuple[ir.ATenOp, ...]: return tuple([x.stride for x in self.op.T.axes]) # type: ignore + @property + def ndim(self) -> int: return len(self.shape) + def _resolve_dim(self, dim: int, *, extra: bool = False) -> int: + total = self.ndim + int(extra) + if not -max(1, total) <= dim <= max(1, total) - 1: + raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}") + return dim + total if dim < 0 else dim + # ref: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/mixin/movement.py#L58 + def _broadcast_to(self, new_shape: tuple[ir.ATenOp, ...]) -> Self: + """ + Implements Numpy-Semantic Broadcasting operation + """ + if ir.ATenOp.equals(self.shape, new_shape): return self + if self.ndim > len(new_shape): + raise ValueError(f"cannot broadcast tensor to fewer dimensions. 
shape={self.shape} to {new_shape}") + shape, _ = align_left(self.shape, new_shape) + if not all(ir.ATenOp.eql(s, ns) or ir.ATenOp.eql(s, 1) for s, ns in zip(shape, new_shape, strict=True)): + raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") + reshaped = self.reshape(shape) + ret = Tensor(op=ir.View.expand(self.op, new_shape)) # type: ignore + return reshaped if ir.ATenOp.equals(ret.shape, reshaped.shape) else ret # type: ignore + @ATen.top + def reshape(self, shape: tuple[Union[int, ir.ATenOp], ...], *args: Any) -> Self: + new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) + if (c := new_shape.count(-1)) > 1: + raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") + if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if ir.ATenOp.eql(s, -1) else s for s in new_shape]) # type: ignore + if not ir.ATenOp.eql(prod(self.shape), prod(new_shape)): + raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") + ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) # type: ignore + return self if ir.ATenOp.equals(ret.shape, self.shape) else ret # type: ignore + @ATen.top + def shrink(self, arg: tuple[tuple[int, int] | None, ...]) -> Self: + raise NotImplementedError("shrink todo") + @ATen.top + def permute(self, order: tuple[int, ...], *args: Any) -> Self: + order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) + if sorted(order_arg) != list(range(self.ndim)): + raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") + return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self # type: ignore + @ATen.top + def expand(self, shape: tuple[Union[int, ir.ATenOp], ...], *args: Any) -> Self: + new_shape = tuple(from_ if ir.ATenOp.eql(to, -1) or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, 
*args))), strict=True)) + return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) # type: ignore + +## arithmetic mixin +class ATenArith(): + def _broadcasted(self, y:Tensor|int|float, reverse:bool=False) -> tuple[Tensor, Tensor]: + x = self + assert isinstance(x, Tensor) + if not isinstance(y, Tensor): + y = Tensor.const(y, dtype=x.dtype) # type: ignore + if x.dtype != y.dtype: # type: ignore + raise TypeError("Cannot add x and y (dtypes mismatch, todo)") + if reverse: x, y = y, x # type: ignore + # compute the output shape + def _broadcast_shape(*shapes:tuple[int|ir.ATenOp, ...]) -> tuple[int|ir.ATenOp, ...]: + def smax(a: int|ir.ATenOp, b: int|ir.ATenOp) -> int|ir.ATenOp: + if ir.ATenOp.eql(a, 1): return b + elif ir.ATenOp.eql(b, 1): return a + else: + assert ir.ATenOp.eql(a, b) + return a # a != b is asserted here? + return tuple(0 if 0 in nth_dim_sizes else smax(*nth_dim_sizes) for nth_dim_sizes in zip(*align_left(*shapes), strict=True)) + assert isinstance(x, Tensor) and isinstance(y, Tensor) + out_shape = _broadcast_shape(x.shape, y.shape) + return x._broadcast_to(out_shape), y._broadcast_to(out_shape) # type: ignore + # TODO: + # - reduce option + # - ir.Add.new (or binop) can have reduce option + @ATen.top + def add(self, other, reverse:bool=False) -> Self: return self.forward(ir.Add, tuple(self._broadcasted(self, other, reverse=reverse))) # type: ignore + @ATen.top + def mul(self, other, reverse:bool=False) -> Self: return self.forward(ir.Mul, tuple(self._broadcasted(self, other, reverse=reverse))) # type: ignore + @ATen.top + def idiv(self, other, reverse:bool=False) -> Self: return self.forward(ir.IDiv, tuple(self._broadcasted(self, other, reverse=reverse))) # type: ignore + + #def __eq__(self, other: Any): pass + # def __neq__(self, other: Any): pass + def __add__(self, other: Any) -> Self: return self.add(other) # type: ignore + def __radd__(self, other: Any) -> Self: return self.add(other, reverse=True) # type: ignore + def 
__mul__(self, other: Any) -> Self: return self.mul(other) # type: ignore + def __rmul__(self, other: Any) -> Self: return self.mul(other, reverse=True) # type: ignore + def __floordiv__(self, other: Any) -> Self: return self.idiv(other) # type: ignore + +## math mixin +class ATenMath(): + @ATen.top + def sin(self: ATen): return self.forward(ir.Sin, self) # type: ignore + @ATen.top + def cos(self: ATen): return self.forward(ir.Sin, self + Tensor.const(0.0, dtype=self.dtype)) # type: ignore +## nn ops mixin +class ATenNN(): + pass +## linalg ops mixin +class ATenLinalg(): + pass +## facet mixin +class Facet(): + # Facet is device transfer abstraction: A.to("CUDA") + # TODO: with tensor.facet("CUDA") as tensor: ... + pass +## abstraction over backends +class ATenBase(ATen, ATenMovements, ATenArith,ATenMath, ATenNN, ATenLinalg, metaclass=ABCMeta): + def __init__(self, *args: Any, op: Union[None, ir.ATenOp]=None): + if op is not None: self.op = op + + ## == AbstractionLayer + @staticmethod + def register(device_id: str, cls: Any) -> None: + DEVICE_TO_TENSOR[device_id] = cls + + @abstractmethod + def allocate(self) -> None: + pass + + @abstractmethod + def free(self) -> None: + pass + + @abstractmethod + def compile(self) -> None: + pass + + @staticmethod + @abstractmethod + def render(op: Any) -> None: + pass + +class Tensor(ATenBase): + def __new__(cls: Any, *args: Any, **kwargs: Any) -> Any: + impl = DEVICE_TO_TENSOR.get(get_backend()) + if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}") + return impl(*args, **kwargs) +## == [Loop-For Style Frontend IR Specs] ====================================== +def kernel(get_kernel: bool = False) -> Callable: + def decorator(func: Callable) -> Callable: + return func + return decorator +# how to generate polyhedral model from tensor ops? 
+# rangeify -> range/when ==> polyhedral model +# with C.range(10, 10): +# with C.when(10, 10) +class Range(): + pass + +class When(): + pass + +class LocalVar(): + pass diff --git a/examples/polyhedral_compiler.ipynb b/examples/polyhedral_compiler.ipynb index e0a3177e..35daea2f 100644 --- a/examples/polyhedral_compiler.ipynb +++ b/examples/polyhedral_compiler.ipynb @@ -640,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 21, "id": "50683011-8c49-4b96-83ad-f588112bf769", "metadata": {}, "outputs": [ @@ -648,6 +648,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "[Before Optimization]\n", "{\n", " for (int c0 = 0; c0 < N; c0 += 1)\n", " for (int c1 = 0; c1 < K_out; c1 += 1)\n", @@ -676,6 +677,7 @@ " assign(PoolBuf[c0][c1][c2][c3], max(PoolBuf[c0][c1][c2][c3], Out[c0][c1][c2 * 4 + c4][c3 * 4 + c5]));\n", "}\n", "\n", + "[After Optimization]\n", "for (int c0 = 0; c0 < N; c0 += 1)\n", " for (int c1 = 0; c1 < K_out; c1 += 1)\n", " for (int c2 = 0; c2 <= 31; c2 += 1)\n", @@ -987,10 +989,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -999,7 +1001,9 @@ "with (conv2d().finalize()+pool2d().finalize()).editor() as kernel:\n", " # Kernel Fusion\n", " with kernel.domain()[0] as dom:\n", + " print(\"[Before Optimization]\")\n", " print(dom.to_c())\n", + "\n", " with dom.sequence() as nk:\n", " nk[0].filter()[0].band().split(2)\n", " nk[1].filter()[0].band().split(2)\n", @@ -1015,6 +1019,7 @@ " hw.fuse()\n", " with dom.band()[0].band()[0].sequence().group(1, 3) as fused:\n", " fused[1].filter()[0].sequence().fuse()\n", + " print(\"[After Optimization]\")\n", " print(kernel.to_c())\n", " from caten.polyhedral.viz import viz_schedule\n", "viz_schedule(kernel.model.schedule.get_root())" @@ -1025,7 +1030,7 @@ "id": "f662092d-3b54-4bcb-a3b6-cda3007f79c8", "metadata": {}, "source": [ - " ## Softmax Optimization" + ") ## Softmax 
Optimization" ] }, { diff --git a/pyproject.toml b/pyproject.toml index a9bb6610..339fd684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["*.ipynb"] [tool.ruff.lint] select = ["E", "F", "I", "B", "Q"] -ignore = ["E501"] +ignore = ["E501", "E701", "E401", "E731"] [tool.mypy] python_version = "3.11" diff --git a/test/test_kernel.py b/test/test_kernel.py new file mode 100644 index 00000000..70e76972 --- /dev/null +++ b/test/test_kernel.py @@ -0,0 +1,28 @@ +import caten as C + + +def test_tensor() -> None: + tensor = C.Tensor.from_shape([10, 10], dtype=C.float32) + print(tensor) + print(tensor.op.T) + print(tensor.reshape([2, 5, 10])) + +#def atest_matmul_kernel(): +# @C.kernel() +# def matmul(A: C.Tensor[N, K], B: C.Tensor[K, M]): +# Out = C.Tensor(N, M, dtype=A.dtype) +# with C.range(N) as i: +# with C.range(M) as j: +# acc = C.LocalVar(0.0) +# with C.range(K) as k: +# acc += + A[i, k] * B[k, j] +# Out[i, j] = C.tanh(acc) +# return Out + # TODO: + # 1. VMAP + # 2. Symbolic +# N = C.param("N") +# tmp = C.randn(N, 10, 10) +# a, b, c = C.randn(10, 10), C.randn(10, 10), C.randn(10, 10) +# c = matmul(a, b, c) +# tmp * c diff --git a/test/test_movements.py b/test/test_movements.py new file mode 100644 index 00000000..db8e22be --- /dev/null +++ b/test/test_movements.py @@ -0,0 +1,9 @@ + +def test_reshape() -> None: + pass + +def test_reshape_const() -> None: + pass + +def test_reshape_dynamic() -> None: + pass