From 101b88b9f6d4ec8d322825f1e0f4e0b5879620a1 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 02:21:40 +0900 Subject: [PATCH 01/53] wip --- caten/helpers.py | 0 caten/ir.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ caten/ops.py | 7 ------- caten/tensor.py | 16 ++++++++++++++++ 4 files changed, 64 insertions(+), 7 deletions(-) create mode 100644 caten/helpers.py create mode 100644 caten/ir.py create mode 100644 caten/tensor.py diff --git a/caten/helpers.py b/caten/helpers.py new file mode 100644 index 00000000..e69de29b diff --git a/caten/ir.py b/caten/ir.py new file mode 100644 index 00000000..b3410210 --- /dev/null +++ b/caten/ir.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import List, Dict, Any + +@dataclass(frozen=True) +class ATenOp(metaclass=ABCMeta): + args: List[AtenOp] + @abstractmethod + @classmethod + def from_astexpr(cls): + pass +## == Tensor Graph ============================================================ +class UnaryOps(): def verify(self): verify_tensor_op(self, 1) +class BinaryOps(): def verify(self): verify_tensor_op(self, 2) +class TernaryOps(): def verify(self): verify_tensor_op(self, 3) + +class Add(ATenOp, BinaryOps): + """ + OUT = Add(X, Y) + """ + @classmethod + def from_ast_expr(cls): + pass + +class Mul(ATenOp, BinaryOps): + """ + OUT = Mul(X, Y) + """ + @classmethod + def from_ast_expr(cls): + pass +## == JIT ===================================================================== +class Reduce(ATenOp): + """ + OUT = Reduce(A, B, op=BinaryOps) + """ + op: BinaryOps + @classmethod + def from_ast_expr(cls): + pass + + +def Var(): + pass + +a = T.Var("A[m n]", float32) +P.stmt("...")[a] diff --git a/caten/ops.py b/caten/ops.py index 0e48c0bb..e69de29b 100644 --- a/caten/ops.py +++ b/caten/ops.py @@ -1,7 +0,0 @@ - -class TOp: - pass - -# UOp.ADD, UOp.MUL, UOp.exp -# Pattern Matcher -# Shape diff --git a/caten/tensor.py b/caten/tensor.py new file mode 100644 index 00000000..87ca9e15 --- /dev/null +++ b/caten/tensor.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import Any, Optional, Tuple, Union +# [TODO] +# - Tensor => Fused Tensor Graph Construction +# - Tensor Kernel Construction +# - Then, start working on auto scheduler +class ATen: + # Tensor has a shape + # Tensor has a stride + # Tensor has a multi level offset + # Tensor can broadcast + # Tensor can have a computation graph + # Can lower + pass + From 9578c3733a3c0b8af0ea0d5b79390c543d018a57 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 02:58:11 +0900 Subject: [PATCH 02/53] concepts --- caten/dtype.py | 25 ++++++++++++ caten/ir.py | 106 +++++++++++++++++++++++++++++++++++++++++++++++- caten/ops.py | 0 caten/tensor.py | 64 ++++++++++++++++++++++++++++- 4 files changed, 192 insertions(+), 3 deletions(-) create mode 100644 caten/dtype.py delete mode 100644 caten/ops.py diff --git a/caten/dtype.py b/caten/dtype.py new file mode 100644 index 00000000..2638aad0 --- /dev/null +++ b/caten/dtype.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from dataclasses import dataclass, fields + +class DTypeMetaClass(type): + dcache: dict[tuple, DType] = {} + def __call__(cls, *args, **kwargs): + if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret + DTypeMetaClass.dcache[args] = ret = super().__call__(*args) + return ret + +# TODO: Vector/Packed DType +@dataclass(frozen=True, eq=False) +class DType: + name: str + @staticmethod + def new(name:str): return DType(name) + +## definitions +float64 = DType.new("float64") +float32 = DType.new("float32") +int64 = DType.new("int64") +int32 = DType.new("int32") + +## dtype aliases +index = int64 diff --git a/caten/ir.py b/caten/ir.py index b3410210..bdf97c6e 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -3,9 +3,18 @@ from abc import ABCMeta, abstractmethod from typing import List, Dict, Any +from .dtype import DType + +@dataclass(frozen=True) +class DTypeContext(): + shape: List[ATenOp] + stride: List[ATenOp] + dtype: DType + @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): args: List[AtenOp] + T: DTypeContext @abstractmethod @classmethod def from_astexpr(cls): @@ -14,7 +23,41 @@ def from_astexpr(cls): class UnaryOps(): def verify(self): verify_tensor_op(self, 1) class BinaryOps(): def verify(self): verify_tensor_op(self, 2) class TernaryOps(): def verify(self): verify_tensor_op(self, 3) +### UnaryOps +class Neg(ATenOp, UnaryOps): + """ + OUT = -X + """ + pass + +class Recip(ATenOp, UnaryOps): + pass + +class Sin(ATenOp, UnaryOps): + pass + +class Exp2(ATenOp, UnaryOps): + pass + +class Log2(ATenOp, UnaryOps): + pass + +class Sqrt(ATenOp, UnaryOps): + pass + +class Cast(ATenOp, UnaryOps): + pass + +class Bitcast(ATenOp, UnaryOps): + pass +class Not(ATenOp, UnaryOps): + """ + Logical not if the X is a boolean + otherwise lognot ~x + """ + pass +### BinaryOps class Add(ATenOp, BinaryOps): """ OUT = Add(X, Y) @@ -30,6 +73,47 @@ class Mul(ATenOp, BinaryOps): @classmethod def from_ast_expr(cls): pass + +class IDiv(ATenOp, BinaryOps): + pass + +class And(ATenOp, BinaryOps): + pass + +class Or(ATenOp, BinaryOps): + pass + +class And(ATenOp, BinaryOps): + pass + +class Xor(ATenOp, BinaryOps): + pass + +class Max(ATenOp, BinaryOps): + pass + +class Mod(ATenOp, BinaryOps): + pass + +class Neq(ATenOp, BinaryOps): + pass + +class Lt(ATenOp, BinaryOps): + pass +### TernaryOps +class Where(ATenOp, TernaryOps): + pass + +### Allocation +class Variable(ATenOp): + symbol: str + +class Allocate(ATenOp): + """ + Allocate(S1, S2, S3, ...) + """ + pass + ## == JIT ===================================================================== class Reduce(ATenOp): """ @@ -40,9 +124,27 @@ class Reduce(ATenOp): def from_ast_expr(cls): pass +class Store(ATenOp): + pass +## ControlFlow +class Range(ATenOp): + pass + +class Loop(ATenOp): + pass + +class When(ATenOp): + pass + +class Progn(ATenOp): + pass +## == ScheduleOps ============================================================ +class Polyhedral(ATenOp): + pass def Var(): pass -a = T.Var("A[m n]", float32) -P.stmt("...")[a] +# e.g.: +# a = T.Var("A[m n]", float32) +# P.stmt("...")[a] diff --git a/caten/ops.py b/caten/ops.py deleted file mode 100644 index e69de29b..00000000 diff --git a/caten/tensor.py b/caten/tensor.py index 87ca9e15..6d9681ce 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -1,16 +1,78 @@ from __future__ import annotations - +from abc import ABCMeta, abstractmethod from typing import Any, Optional, Tuple, Union + +import caten.ir as ir # [TODO] # - Tensor => Fused Tensor Graph Construction # - Tensor Kernel Construction # - Then, start working on auto scheduler +DEVICE_TO_TENSOR = {} + +class ATenSpec: + """ + Usage: C.Tensor[float32, "M", "N"] -> TensorSpec(M N) + """ + def __init__(self, shape: Tuple[Any, ...], dtype: Any = None): + self.shape = shape + self.dtype = dtype + def __repr__(self) -> str: return f"TensorSpec({self.shape}, {self.dtype})" + class ATen: + op: ATenOp # Just a wrapper for ATenOp + @classmethod + def from_shape(cls, shape: List[ATenOp]): + pass + + def apply(self, other: Any, func: Any) -> Tensor: + other_node = other.node if isinstance(other, Tensor) else other + res_node = func(self.node, other_node) + return ATen(node=res_node) + + def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: + # Usage: C.Tensor[10, 10] -> TensorSpec((10, 10)) + pass + + def realize(self): + pass + +class ATenMath(): + def add(self, other): + pass + +class ATenMovements(): + pass + +class ATenNN(): + pass + +class ATenLinalg(): + pass + +class ATenMeta(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): # Tensor has a shape # Tensor has a stride # Tensor has a multi level offset # Tensor can broadcast # Tensor can have a computation graph # Can lower + ## == AbstractionLayer + @abstractmethod + def allocate(self): + pass + + @abstractmethod + def free(self): + pass + +## For-Style Graph Construction +def kernel(get_kernel: bool = False) -> Callable: + def decorator(func: Callable) -> Callable: + pass + return decorator + +class Range(): pass +class When(): + pass From fec84c03e934d8eefb4a8a3efbb921eb435fae54 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:22:41 +0900 Subject: [PATCH 03/53] wip --- caten/__init__.py | 10 ++++++++++ caten/ir.py | 35 ++++++++++++++++++++++++--------- caten/runtime/cpu.py | 13 ++++++++++++ caten/tensor.py | 47 +++++++++++++++++++++++++++----------------- test/test_kernel.py | 4 ++++ 5 files changed, 82 insertions(+), 27 deletions(-) create mode 100644 caten/runtime/cpu.py create mode 100644 test/test_kernel.py diff --git a/caten/__init__.py b/caten/__init__.py index e69de29b..a4065749 100644 --- a/caten/__init__.py +++ b/caten/__init__.py @@ -0,0 +1,10 @@ +from . import dtype, helpers, ir, tensor +from .tensor import ATenSpec, ATen, ATenMath, ATenMovements, ATenNN, ATenLinalg, ATenBase, get_backend, Tensor +from .runtime import cpu + +__all__ = [ + "dtype", + "helpers", + "ir", + "tensor" +] diff --git a/caten/ir.py b/caten/ir.py index bdf97c6e..e4a4559e 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -2,27 +2,44 @@ from abc import ABCMeta, abstractmethod from typing import List, Dict, Any - +from dataclasses import dataclass from .dtype import DType @dataclass(frozen=True) -class DTypeContext(): - shape: List[ATenOp] - stride: List[ATenOp] +class ATenAxis(): + shape: ATenOp + stride: ATenOp + offset: ATenOp + incf: ATenOp + def index(self, i: ATenOp): + # TODO: Assert i.T.dtype is dtype.index + return Mul(self.stride, Add(Mul(i, self.incf), self.offset)) + +@dataclass(frozen=True) +class ATenOpType(): + shape: List[ATenAxis] dtype: DType + offset: ATenOp @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): args: List[AtenOp] - T: DTypeContext - @abstractmethod + T: ATenOpType @classmethod + @abstractmethod def from_astexpr(cls): pass + + @abstractmethod + def infer_dtype(self): + pass ## == Tensor Graph ============================================================ -class UnaryOps(): def verify(self): verify_tensor_op(self, 1) -class BinaryOps(): def verify(self): verify_tensor_op(self, 2) -class TernaryOps(): def verify(self): verify_tensor_op(self, 3) +class UnaryOps(): + def verify(self): verify_tensor_op(self, 1) +class BinaryOps(): + def verify(self): verify_tensor_op(self, 2) +class TernaryOps(): + def verify(self): verify_tensor_op(self, 3) ### UnaryOps class Neg(ATenOp, UnaryOps): """ diff --git a/caten/runtime/cpu.py b/caten/runtime/cpu.py new file mode 100644 index 00000000..6c06816c --- /dev/null +++ b/caten/runtime/cpu.py @@ -0,0 +1,13 @@ +import caten as C + +class CPUTensor(C.ATenBase): + def allocate(self): + pass + + def free(self): + pass + + def compile(self): + pass + +C.ATenBase.register("CPU", CPUTensor) diff --git a/caten/tensor.py b/caten/tensor.py index 6d9681ce..460986ed 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -1,13 +1,16 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import Any, Optional, Tuple, Union - +from typing import Any, Optional, Tuple, Union, ClassVar +import os import caten.ir as ir # [TODO] # - Tensor => Fused Tensor Graph Construction # - Tensor Kernel Construction # - Then, start working on auto scheduler +## Backend Abstraction DEVICE_TO_TENSOR = {} +def get_backend(): return os.environ.get("BACKEND", "CPU") +## class ATenSpec: """ @@ -19,26 +22,23 @@ def __init__(self, shape: Tuple[Any, ...], dtype: Any = None): def __repr__(self) -> str: return f"TensorSpec({self.shape}, {self.dtype})" class ATen: - op: ATenOp # Just a wrapper for ATenOp + op: ATenOp # ATen is just a wrapper for ATenOp @classmethod def from_shape(cls, shape: List[ATenOp]): - pass + return ir.Allocate(shape) # TODO - def apply(self, other: Any, func: Any) -> Tensor: - other_node = other.node if isinstance(other, Tensor) else other - res_node = func(self.node, other_node) - return ATen(node=res_node) + def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: # Usage: C.Tensor[10, 10] -> TensorSpec((10, 10)) + # TODO pass - def realize(self): + def polyhedral(self): pass class ATenMath(): - def add(self, other): - pass + pass class ATenMovements(): pass @@ -49,14 +49,12 @@ class ATenNN(): class ATenLinalg(): pass -class ATenMeta(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): - # Tensor has a shape - # Tensor has a stride - # Tensor has a multi level offset - # Tensor can broadcast - # Tensor can have a computation graph - # Can lower +class ATenBase(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): ## == AbstractionLayer + @staticmethod + def register(device_id: str, cls: ClassVar): + DEVICE_TO_TENSOR[device_id] = cls + @abstractmethod def allocate(self): pass @@ -65,6 +63,19 @@ def allocate(self): def free(self): pass + @abstractmethod + def compile(self): + pass + +# TODO: Tensor(3, 3)ってやったら,自動でTensor = CPUTensorとかになる +class Tensor(ATenBase): + def __new__(cls, *args, **kwargs): + if cls is Tensor: + impl = DEVICE_TO_TENSOR.get(get_backend()) + if impl is None: + raise ValueError(f"Unknown BACKEND={get_backend()}") + return impl(*args, **kwargs) + return super().__new__(cls) ## For-Style Graph Construction def kernel(get_kernel: bool = False) -> Callable: def decorator(func: Callable) -> Callable: diff --git a/test/test_kernel.py b/test/test_kernel.py new file mode 100644 index 00000000..9a2807db --- /dev/null +++ b/test/test_kernel.py @@ -0,0 +1,4 @@ +import caten as C + +def test_tensor(): + print(C.Tensor()) From 7856871be75d5ee66606a2bdc697f98434030307 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:23:41 +0900 Subject: [PATCH 04/53] wip --- caten/tensor.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/caten/tensor.py b/caten/tensor.py index 460986ed..eec1b669 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -67,15 +67,11 @@ def free(self): def compile(self): pass -# TODO: Tensor(3, 3)ってやったら,自動でTensor = CPUTensorとかになる class Tensor(ATenBase): def __new__(cls, *args, **kwargs): - if cls is Tensor: - impl = DEVICE_TO_TENSOR.get(get_backend()) - if impl is None: - raise ValueError(f"Unknown BACKEND={get_backend()}") - return impl(*args, **kwargs) - return super().__new__(cls) + impl = DEVICE_TO_TENSOR.get(get_backend()) + if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}") + return impl(*args, **kwargs) ## For-Style Graph Construction def kernel(get_kernel: bool = False) -> Callable: def decorator(func: Callable) -> Callable: From 2d4c0ce5058a661961931fff20216aa61d7e2bc2 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:24:45 +0900 Subject: [PATCH 05/53] wip --- caten/tensor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/caten/tensor.py b/caten/tensor.py index eec1b669..4452f691 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -78,8 +78,12 @@ def decorator(func: Callable) -> Callable: pass return decorator -class Range(): +# how to generate polyhedral model from tensor ops? +# rangeify -> range/when ==> polyhedral model +# with C.range(10, 10): +# with C.when(10, 10) +class range(): pass -class When(): +class when(): pass From aa350ecc2a2120c548d3e025ec58b89a9c62bb2d Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:28:29 +0900 Subject: [PATCH 06/53] facet ideas --- caten/tensor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/caten/tensor.py b/caten/tensor.py index 4452f691..e1ccf8a2 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -49,6 +49,11 @@ class ATenNN(): class ATenLinalg(): pass +class Facet(): + # Facet is device transfer abstraction: A.to("CUDA") + # TODO: with tensor.facet("CUDA") as tensor: ... + pass + class ATenBase(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): ## == AbstractionLayer @staticmethod From d800a8eb5191c8f7364f813565c5fd93552ad767 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:42:40 +0900 Subject: [PATCH 07/53] allocation graph --- caten/__init__.py | 3 ++- caten/ir.py | 38 ++++++++++++++++++++++++++++---------- caten/tensor.py | 5 +++-- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/caten/__init__.py b/caten/__init__.py index a4065749..a98f46ef 100644 --- a/caten/__init__.py +++ b/caten/__init__.py @@ -1,5 +1,6 @@ from . import dtype, helpers, ir, tensor -from .tensor import ATenSpec, ATen, ATenMath, ATenMovements, ATenNN, ATenLinalg, ATenBase, get_backend, Tensor +from .dtype import * +from .tensor import * from .runtime import cpu __all__ = [ diff --git a/caten/ir.py b/caten/ir.py index e4a4559e..3e6e1ea8 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -2,8 +2,9 @@ from abc import ABCMeta, abstractmethod from typing import List, Dict, Any +import itertools from dataclasses import dataclass -from .dtype import DType +from .dtype import DType, index @dataclass(frozen=True) class ATenAxis(): @@ -19,18 +20,30 @@ def index(self, i: ATenOp): class ATenOpType(): shape: List[ATenAxis] dtype: DType - offset: ATenOp - + offset: Union[ATenOp, None] = None + @staticmethod + def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: + def _const(val: int): return Const.new(val, index) + def _mul(a, b): + if not isinstance(a, Const): a = _const(a) + if not isinstance(b, Const): b = _const(b) + return Mul([a, b]) + strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] + return ATenOpType( + shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], + dtype=dtype, + ) + @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): args: List[AtenOp] - T: ATenOpType + T: Union[ATenOpType, None] = None @classmethod - @abstractmethod +# @abstractmethod def from_astexpr(cls): pass - @abstractmethod +# @abstractmethod def infer_dtype(self): pass ## == Tensor Graph ============================================================ @@ -122,15 +135,20 @@ class Where(ATenOp, TernaryOps): pass ### Allocation -class Variable(ATenOp): - symbol: str +class Const(ATenOp): + value: Union[int, float, str] + @staticmethod + def new(val: Union[int, float, str], dtype: DType): + return Const(val, T=ATenOpType(shape=[], dtype=dtype)) class Allocate(ATenOp): """ Allocate(S1, S2, S3, ...) """ - pass - + @staticmethod + def new(shape: List[Any], dtype: DType): + return Allocate(shape, T=ATenOpType.from_shape(shape, dtype)) + ## == JIT ===================================================================== class Reduce(ATenOp): """ diff --git a/caten/tensor.py b/caten/tensor.py index e1ccf8a2..41837fc9 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -3,6 +3,7 @@ from typing import Any, Optional, Tuple, Union, ClassVar import os import caten.ir as ir +from .dtype import float32 # [TODO] # - Tensor => Fused Tensor Graph Construction # - Tensor Kernel Construction @@ -24,8 +25,8 @@ def __repr__(self) -> str: return f"TensorSpec({self.shape}, {self.dtype})" class ATen: op: ATenOp # ATen is just a wrapper for ATenOp @classmethod - def from_shape(cls, shape: List[ATenOp]): - return ir.Allocate(shape) # TODO + def from_shape(cls, shape: List[ATenOp], dtype: DType=float32): + return ir.Allocate.new(shape, dtype) def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) From 51de548861eec9d20bf7c5fc91679fbdd758aa49 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:49:53 +0900 Subject: [PATCH 08/53] allocation graph --- caten/ir.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 3e6e1ea8..b82450be 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -21,6 +21,13 @@ class ATenOpType(): shape: List[ATenAxis] dtype: DType offset: Union[ATenOp, None] = None + def index(self, indices: List[ATenOp]): + assert self.ndim == len(indices) + total = itertools.accumlate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(val, index)) + if self.offset: total = Add([total, self.offset]) + return total + @property + def ndim(self): return len(self.shape) @staticmethod def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: def _const(val: int): return Const.new(val, index) @@ -41,11 +48,11 @@ class ATenOp(metaclass=ABCMeta): @classmethod # @abstractmethod def from_astexpr(cls): - pass - + pass # @abstractmethod - def infer_dtype(self): + def verify(self): pass + ## == Tensor Graph ============================================================ class UnaryOps(): def verify(self): verify_tensor_op(self, 1) From 557dc10f111b6ca67fbd47322f74b50b352765de Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:53:18 +0900 Subject: [PATCH 09/53] allocation graph --- caten/ir.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/caten/ir.py b/caten/ir.py index b82450be..de1a3b54 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -13,7 +13,7 @@ class ATenAxis(): offset: ATenOp incf: ATenOp def index(self, i: ATenOp): - # TODO: Assert i.T.dtype is dtype.index + assert i.T.dtype == index, "ATenAxis.index: range index should be type of index." return Mul(self.stride, Add(Mul(i, self.incf), self.offset)) @dataclass(frozen=True) @@ -45,6 +45,8 @@ def _mul(a, b): class ATenOp(metaclass=ABCMeta): args: List[AtenOp] T: Union[ATenOpType, None] = None + # TODO: Cached? + # def __init__(self, ...) @classmethod # @abstractmethod def from_astexpr(cls): @@ -52,6 +54,10 @@ def from_astexpr(cls): # @abstractmethod def verify(self): pass + + def coalese(self): + # Simplify myself + pass ## == Tensor Graph ============================================================ class UnaryOps(): From 2ab506b6e40a99aefb87f7bb777f73730a7ad3ca Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 03:57:40 +0900 Subject: [PATCH 10/53] allocation graph --- caten/ir.py | 6 ++++++ caten/tensor.py | 17 ++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index de1a3b54..0a9941ab 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -40,6 +40,12 @@ def _mul(a, b): shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, ) + def reshape(self): + pass + def permute(self): + pass + def expand(self): + pass @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): diff --git a/caten/tensor.py b/caten/tensor.py index 41837fc9..13b294ca 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -11,8 +11,7 @@ ## Backend Abstraction DEVICE_TO_TENSOR = {} def get_backend(): return os.environ.get("BACKEND", "CPU") -## - +## Annotation class ATenSpec: """ Usage: C.Tensor[float32, "M", "N"] -> TensorSpec(M N) @@ -21,7 +20,7 @@ def __init__(self, shape: Tuple[Any, ...], dtype: Any = None): self.shape = shape self.dtype = dtype def __repr__(self) -> str: return f"TensorSpec({self.shape}, {self.dtype})" - +## Tensor datastrucure class ATen: op: ATenOp # ATen is just a wrapper for ATenOp @classmethod @@ -37,24 +36,24 @@ def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: def polyhedral(self): pass - +## math mixin class ATenMath(): pass - +## movement ops mixin class ATenMovements(): pass - +## nn ops mixin class ATenNN(): pass - +## linalg ops mixin class ATenLinalg(): pass - +## facet mixin class Facet(): # Facet is device transfer abstraction: A.to("CUDA") # TODO: with tensor.facet("CUDA") as tensor: ... pass - +## abstraction over backends class ATenBase(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): ## == AbstractionLayer @staticmethod From 673906c0802bc044e886d20c79e4481c23f8aa26 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 04:02:00 +0900 Subject: [PATCH 11/53] concrete semantics --- caten/ir.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 0a9941ab..47b76b21 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -21,6 +21,7 @@ class ATenOpType(): shape: List[ATenAxis] dtype: DType offset: Union[ATenOp, None] = None + is_ptr: bool = False # for vectorize def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) total = itertools.accumlate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(val, index)) @@ -53,6 +54,10 @@ class ATenOp(metaclass=ABCMeta): T: Union[ATenOpType, None] = None # TODO: Cached? # def __init__(self, ...) + def predecessors(self): + # TODO: + # - Tに含まれるOpsをReadに含める + pass @classmethod # @abstractmethod def from_astexpr(cls): @@ -166,8 +171,15 @@ class Allocate(ATenOp): """ @staticmethod def new(shape: List[Any], dtype: DType): - return Allocate(shape, T=ATenOpType.from_shape(shape, dtype)) - + return Allocate([], T=ATenOpType.from_shape(shape, dtype)) + +class View(ATenOp): + """ + View(X, T=T_New) + """ + @staticmethod + def new(tensor: ATenOp, view: ATenOpType): + pass ## == JIT ===================================================================== class Reduce(ATenOp): """ From e4ddd5e3db5d8715f3dd6c3fe3fbebac82faeada Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 04:09:25 +0900 Subject: [PATCH 12/53] dataclass --- caten/ir.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 47b76b21..ec3a79e3 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -78,33 +78,38 @@ def verify(self): verify_tensor_op(self, 2) class TernaryOps(): def verify(self): verify_tensor_op(self, 3) ### UnaryOps +@dataclass(frozen=True) class Neg(ATenOp, UnaryOps): """ OUT = -X """ pass +@dataclass(frozen=True) class Recip(ATenOp, UnaryOps): pass +@dataclass(frozen=True) class Sin(ATenOp, UnaryOps): pass +@dataclass(frozen=True) class Exp2(ATenOp, UnaryOps): pass +@dataclass(frozen=True) class Log2(ATenOp, UnaryOps): pass +@dataclass(frozen=True) class Sqrt(ATenOp, UnaryOps): pass -class Cast(ATenOp, UnaryOps): - pass - +@dataclass(frozen=True) class Bitcast(ATenOp, UnaryOps): pass +@dataclass(frozen=True) class Not(ATenOp, UnaryOps): """ Logical not if the X is a boolean @@ -112,6 +117,7 @@ class Not(ATenOp, UnaryOps): """ pass ### BinaryOps +@dataclass(frozen=True) class Add(ATenOp, BinaryOps): """ OUT = Add(X, Y) @@ -120,6 +126,7 @@ class Add(ATenOp, BinaryOps): def from_ast_expr(cls): pass +@dataclass(frozen=True) class Mul(ATenOp, BinaryOps): """ OUT = Mul(X, Y) @@ -128,43 +135,55 @@ class Mul(ATenOp, BinaryOps): def from_ast_expr(cls): pass +@dataclass(frozen=True) class IDiv(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class And(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class Or(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class And(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class Xor(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class Max(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class Mod(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class Neq(ATenOp, BinaryOps): pass +@dataclass(frozen=True) class Lt(ATenOp, BinaryOps): pass ### TernaryOps +@dataclass(frozen=True) class Where(ATenOp, TernaryOps): pass ### Allocation +@dataclass(frozen=True) class Const(ATenOp): - value: Union[int, float, str] + value: Union[int, float, str] = 0.0 @staticmethod - def new(val: Union[int, float, str], dtype: DType): - return Const(val, T=ATenOpType(shape=[], dtype=dtype)) + def new(value: Union[int, float, str], dtype: DType): + return Const(args=[], value=value, T=ATenOpType(shape=[], dtype=dtype)) +@dataclass(frozen=True) class Allocate(ATenOp): """ Allocate(S1, S2, S3, ...) @@ -173,6 +192,7 @@ class Allocate(ATenOp): def new(shape: List[Any], dtype: DType): return Allocate([], T=ATenOpType.from_shape(shape, dtype)) +@dataclass(frozen=True) class View(ATenOp): """ View(X, T=T_New) @@ -181,6 +201,7 @@ class View(ATenOp): def new(tensor: ATenOp, view: ATenOpType): pass ## == JIT ===================================================================== +@dataclass(frozen=True) class Reduce(ATenOp): """ OUT = Reduce(A, B, op=BinaryOps) @@ -190,21 +211,27 @@ class Reduce(ATenOp): def from_ast_expr(cls): pass +@dataclass(frozen=True) class Store(ATenOp): pass ## ControlFlow +@dataclass(frozen=True) class Range(ATenOp): pass +@dataclass(frozen=True) class Loop(ATenOp): pass +@dataclass(frozen=True) class When(ATenOp): pass +@dataclass(frozen=True) class Progn(ATenOp): pass ## == ScheduleOps ============================================================ +@dataclass(frozen=True) class Polyhedral(ATenOp): pass From 3ac5d966c3f55498edb600fcf991c45dadd4da25 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 04:10:02 +0900 Subject: [PATCH 13/53] dataclass --- caten/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caten/ir.py b/caten/ir.py index ec3a79e3..240cbc7a 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -206,7 +206,7 @@ class Reduce(ATenOp): """ OUT = Reduce(A, B, op=BinaryOps) """ - op: BinaryOps + op: BinaryOps = Add @classmethod def from_ast_expr(cls): pass From e66e6268bb60a167fc0f259d4cfe383ab43b77c1 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 04:11:01 +0900 Subject: [PATCH 14/53] dataclass --- caten/ir.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 240cbc7a..735cdd9d 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import List, Dict, Any +from typing import List, Dict, Any, Union import itertools from dataclasses import dataclass from .dtype import DType, index @@ -24,7 +24,7 @@ class ATenOpType(): is_ptr: bool = False # for vectorize def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) - total = itertools.accumlate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(val, index)) + total = itertools.accumlate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(0, index)) if self.offset: total = Add([total, self.offset]) return total @property @@ -36,7 +36,7 @@ def _mul(a, b): if not isinstance(a, Const): a = _const(a) if not isinstance(b, Const): b = _const(b) return Mul([a, b]) - strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] + strides = tuple(itertools.accumlate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, @@ -50,7 +50,7 @@ def expand(self): @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): - args: List[AtenOp] + args: List[ATenOp] T: Union[ATenOpType, None] = None # TODO: Cached? # def __init__(self, ...) From e4bacfd92a6ce9c00811e7a96edff69dd7fa97ac Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 04:11:29 +0900 Subject: [PATCH 15/53] dataclass --- caten/ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 735cdd9d..3a4e7ecb 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -24,7 +24,7 @@ class ATenOpType(): is_ptr: bool = False # for vectorize def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) - total = itertools.accumlate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(0, index)) + total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(0, index)) if self.offset: total = Add([total, self.offset]) return total @property @@ -36,7 +36,7 @@ def _mul(a, b): if not isinstance(a, Const): a = _const(a) if not isinstance(b, Const): b = _const(b) return Mul([a, b]) - strides = tuple(itertools.accumlate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] + strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, From a3056eb4364fd094403f092254d8c9e6703f351a Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 04:12:43 +0900 Subject: [PATCH 16/53] tuple --- caten/ir.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 3a4e7ecb..3701308f 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -18,13 +18,13 @@ def index(self, i: ATenOp): @dataclass(frozen=True) class ATenOpType(): - shape: List[ATenAxis] + shape: tuple[ATenAxis] dtype: DType offset: Union[ATenOp, None] = None is_ptr: bool = False # for vectorize def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) - total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add([a, b]), initial=Const.new(0, index)) + total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) if self.offset: total = Add([total, self.offset]) return total @property @@ -35,7 +35,7 @@ def _const(val: int): return Const.new(val, index) def _mul(a, b): if not isinstance(a, Const): a = _const(a) if not isinstance(b, Const): b = _const(b) - return Mul([a, b]) + return Mul((a, b)) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], @@ -181,7 +181,7 @@ class Const(ATenOp): value: Union[int, float, str] = 0.0 @staticmethod def new(value: Union[int, float, str], dtype: DType): - return Const(args=[], value=value, T=ATenOpType(shape=[], dtype=dtype)) + return Const(args=(), value=value, T=ATenOpType(shape=[], dtype=dtype)) @dataclass(frozen=True) class Allocate(ATenOp): @@ -190,7 +190,7 @@ class Allocate(ATenOp): """ @staticmethod def new(shape: List[Any], dtype: DType): - return Allocate([], T=ATenOpType.from_shape(shape, dtype)) + return Allocate((), T=ATenOpType.from_shape(shape, dtype)) @dataclass(frozen=True) class View(ATenOp): From e2599248cdefdd9561e340f8ef09d1ffdf90971e Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 17:02:27 +0900 Subject: [PATCH 17/53] concrete semantics --- caten/ir.py | 15 +++++++++++++-- caten/tensor.py | 11 ++++++++++- test/test_kernel.py | 23 ++++++++++++++++++++++- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 3701308f..9d817a6d 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -41,22 +41,30 @@ def _mul(a, b): shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, ) - def reshape(self): - pass + + def reshape(self, shape: List[ATenOp]): + return View.new(x, ) + def permute(self): pass + def expand(self): pass + def cast(self): + pass + @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): args: List[ATenOp] T: Union[ATenOpType, None] = None # TODO: Cached? # def __init__(self, ...) + @property def predecessors(self): # TODO: # - Tに含まれるOpsをReadに含める + # - RangifyしたらSymbolicのDepsは消える pass @classmethod # @abstractmethod @@ -69,6 +77,9 @@ def verify(self): def coalese(self): # Simplify myself pass + + def deepwalk(self): + pass ## == Tensor Graph ============================================================ class UnaryOps(): diff --git a/caten/tensor.py b/caten/tensor.py index 13b294ca..82f0ca02 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -26,7 +26,7 @@ class ATen: @classmethod def from_shape(cls, shape: List[ATenOp], dtype: DType=float32): return ir.Allocate.new(shape, dtype) - + def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: @@ -34,8 +34,17 @@ def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: # TODO pass + @staticmethod + def top(): + # Register the given method as @C.sin callable + pass + def polyhedral(self): pass +## arithmetic mixin +class ATenArith(): + def __add__(self, other): + pass ## math mixin class ATenMath(): pass diff --git a/test/test_kernel.py b/test/test_kernel.py index 9a2807db..be4fff84 100644 --- a/test/test_kernel.py +++ b/test/test_kernel.py @@ -1,4 +1,25 @@ import caten as C def test_tensor(): - print(C.Tensor()) + print(C.Tensor.from_shape([10, 10], dtype=C.float32)) + +def test_matmul_kernel(): + @C.kernel() + def matmul(A: C.Tensor[N, K], B: C.Tensor[K, M]): + Out = C.Tensor(N, M, dtype=A.dtype) + with C.range(N) as i: + with C.range(M) as j: + acc = C.Const(0.0) + with C.range(K) as k: + acc += + A[i, k] * B[k, j] + Out[i, j] = C.tanh(acc) + return Out + + # TODO: + # 1. VMAP + # 2. Symbolic + N = C.param("N") + tmp = C.randn(N, 10, 10) + a, b, c = C.randn(10, 10), C.randn(10, 10), C.randn(10, 10) + c = matmul(a, b, c) + tmp * c From 227b779352e67021f716a3ed531cc80b24da28ac Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 17:02:31 +0900 Subject: [PATCH 18/53] concrete semantics --- examples/polyhedral_compiler.ipynb | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/polyhedral_compiler.ipynb b/examples/polyhedral_compiler.ipynb index e0a3177e..35daea2f 100644 --- a/examples/polyhedral_compiler.ipynb +++ b/examples/polyhedral_compiler.ipynb @@ -640,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 21, "id": "50683011-8c49-4b96-83ad-f588112bf769", "metadata": {}, "outputs": [ @@ -648,6 +648,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "[Before Optimization]\n", "{\n", " for (int c0 = 0; c0 < N; c0 += 1)\n", " for (int c1 = 0; c1 < K_out; c1 += 1)\n", @@ -676,6 +677,7 @@ " assign(PoolBuf[c0][c1][c2][c3], max(PoolBuf[c0][c1][c2][c3], Out[c0][c1][c2 * 4 + c4][c3 * 4 + c5]));\n", "}\n", "\n", + "[After Optimization]\n", "for (int c0 = 0; c0 < N; c0 += 1)\n", " for (int c1 = 0; c1 < K_out; c1 += 1)\n", " for (int c2 = 0; c2 <= 31; c2 += 1)\n", @@ -987,10 +989,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -999,7 +1001,9 @@ "with (conv2d().finalize()+pool2d().finalize()).editor() as kernel:\n", " # Kernel Fusion\n", " with kernel.domain()[0] as dom:\n", + " print(\"[Before Optimization]\")\n", " print(dom.to_c())\n", + "\n", " with dom.sequence() as nk:\n", " nk[0].filter()[0].band().split(2)\n", " nk[1].filter()[0].band().split(2)\n", @@ -1015,6 +1019,7 @@ " hw.fuse()\n", " with dom.band()[0].band()[0].sequence().group(1, 3) as fused:\n", " fused[1].filter()[0].sequence().fuse()\n", + " print(\"[After Optimization]\")\n", " print(kernel.to_c())\n", " from caten.polyhedral.viz import viz_schedule\n", "viz_schedule(kernel.model.schedule.get_root())" @@ -1025,7 +1030,7 @@ "id": "f662092d-3b54-4bcb-a3b6-cda3007f79c8", "metadata": {}, "source": [ - " ## Softmax Optimization" + ") ## Softmax Optimization" ] }, { From 19db18bb8441ebe5ac58cfd1528196a8b73ade16 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 17:35:36 +0900 Subject: [PATCH 19/53] concrete semantics --- caten/helpers.py | 13 ++++++++ caten/ir.py | 16 +++------- caten/tensor.py | 80 ++++++++++++++++++++++++++++-------------------- 3 files changed, 64 insertions(+), 45 deletions(-) diff --git a/caten/helpers.py b/caten/helpers.py index e69de29b..bb7bed02 100644 --- a/caten/helpers.py +++ b/caten/helpers.py @@ -0,0 +1,13 @@ +from __future__ import annotations +from typing import Iterable, TypeVar +import functools, operator + +T = TypeVar("T") + +def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1) + +def argfix(*x): + if x and x[0].__class__ in (tuple, list): + if len(x) != 1: raise ValueError(f"bad arg {x}") + return tuple(x[0]) + return x diff --git a/caten/ir.py b/caten/ir.py index 9d817a6d..ae5c28ff 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -41,18 +41,6 @@ def _mul(a, b): shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, ) - - def reshape(self, shape: List[ATenOp]): - return View.new(x, ) - - def permute(self): - pass - - def expand(self): - pass - - def cast(self): - pass @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): @@ -208,6 +196,10 @@ class View(ATenOp): """ View(X, T=T_New) """ + @staticmethod + def reshape(tensor: ATenOp, shape: List[ATenOp]): + pass + @staticmethod def new(tensor: ATenOp, view: ATenOpType): pass diff --git a/caten/tensor.py b/caten/tensor.py index 82f0ca02..de501390 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -1,56 +1,67 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from typing import Any, Optional, Tuple, Union, ClassVar +from typing import Any, Optional, Tuple, Union, ClassVar, Self import os import caten.ir as ir from .dtype import float32 -# [TODO] -# - Tensor => Fused Tensor Graph Construction -# - Tensor Kernel Construction -# - Then, start working on auto scheduler +from caten.helpers import argfix, prod ## Backend Abstraction DEVICE_TO_TENSOR = {} def get_backend(): return os.environ.get("BACKEND", "CPU") -## Annotation +## Tensor annotation for jit/aot shape check class ATenSpec: """ - Usage: C.Tensor[float32, "M", "N"] -> TensorSpec(M N) + C.Tensor[M, N] -> ATenSpec(M N) """ - def __init__(self, shape: Tuple[Any, ...], dtype: Any = None): - self.shape = shape - self.dtype = dtype - def __repr__(self) -> str: return f"TensorSpec({self.shape}, {self.dtype})" -## Tensor datastrucure + def __init__(self, shape: Tuple[Any, ...]): + self.shape: List[Union[int, str]] = shape + def __repr__(self) -> str: return f"ATenSpec{self.shape}" +## Tensor compiler core class ATen: op: ATenOp # ATen is just a wrapper for ATenOp @classmethod - def from_shape(cls, shape: List[ATenOp], dtype: DType=float32): - return ir.Allocate.new(shape, dtype) - - def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) - - def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: - # Usage: C.Tensor[10, 10] -> TensorSpec((10, 10)) - # TODO - pass - + def from_shape(cls, shape: List[ATenOp], dtype: DType=float32): return ir.Allocate.new(shape, dtype) + def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) + def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: return TensorSpec(item) + # TODO: Display Shape, realized buffer, etc. + def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.node}>" @staticmethod - def top(): - # Register the given method as @C.sin callable + def top(f: Callable[Any, ATen]): + """ + Declares the given function as toplevel tensor operation. + """ + # TODO: Toplevel in helpers.py + return f + def polyhedral(self): pass - def polyhedral(self): +## movement ops mixin +class ATenMovements(): + @property + def shape(self): -> List[ATen]: pass + + @ATen.top + def reshape(self, shape, *args) -> Self: + new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) + if (c := new_shape.count(-1)) > 1: + raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") + if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) + if prod(self.shape) != prod(new_shape): + raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") + ret = ATen(op=ir.View.reshape(self.op, new_shape)) # TODO: new_shape is ATenOp? + return self if ret.shape == self.shape else ret ## arithmetic mixin class ATenArith(): - def __add__(self, other): + @ATen.top + def add(self, other): pass ## math mixin class ATenMath(): - pass -## movement ops mixin -class ATenMovements(): - pass + @ATen.top + def sin(self: ATen): return self.forward(ir.Sin, self) + @ATen.top + def cos(self: ATen): return self.forward(ir.Sin, self + Tensor(0.0)) ## nn ops mixin class ATenNN(): pass @@ -86,7 +97,7 @@ def __new__(cls, *args, **kwargs): impl = DEVICE_TO_TENSOR.get(get_backend()) if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}") return impl(*args, **kwargs) -## For-Style Graph Construction +## == [Loop-For Style Frontend IR Specs] ====================================== def kernel(get_kernel: bool = False) -> Callable: def decorator(func: Callable) -> Callable: pass @@ -96,8 +107,11 @@ def decorator(func: Callable) -> Callable: # rangeify -> range/when ==> polyhedral model # with C.range(10, 10): # with C.when(10, 10) -class range(): +class Range(): + pass + +class When(): pass -class when(): +class LocalVar(): pass From 74471c9ad62dcba62af96e31fd53002724a5d235 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 17:44:27 +0900 Subject: [PATCH 20/53] WIP --- caten/dtype.py | 12 ++++++++++++ caten/tensor.py | 28 +++++++++++++++++++++++----- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/caten/dtype.py b/caten/dtype.py index 2638aad0..a9f7963d 100644 --- a/caten/dtype.py +++ b/caten/dtype.py @@ -18,8 +18,20 @@ def new(name:str): return DType(name) ## definitions float64 = DType.new("float64") float32 = DType.new("float32") +float16 = DType.new("float16") + int64 = DType.new("int64") int32 = DType.new("int32") +int16 = DType.new("int16") +int8 = DType.new("int8") +uint64 = DType.new("uint64") +uint32 = DType.new("uint32") +uint16 = DType.new("uint16") +uint8 = DType.new("uint8") ## dtype aliases index = int64 +default_float = float32 + +floats = [float64, float32, float16] +integers = [int64, int32, int16, int8, uint64, uint32, uint16, uint8] diff --git a/caten/tensor.py b/caten/tensor.py index de501390..39c26c3c 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Tuple, Union, ClassVar, Self import os import caten.ir as ir -from .dtype import float32 +from .dtype import default_float, index, floats, integers from caten.helpers import argfix, prod ## Backend Abstraction DEVICE_TO_TENSOR = {} @@ -20,11 +20,30 @@ def __repr__(self) -> str: return f"ATenSpec{self.shape}" class ATen: op: ATenOp # ATen is just a wrapper for ATenOp @classmethod - def from_shape(cls, shape: List[ATenOp], dtype: DType=float32): return ir.Allocate.new(shape, dtype) + def from_shape(cls, shape: List[ATenOp], dtype: DType=default_float): return ir.Allocate.new(shape, dtype) + @classmethod + def const(cls, obj: Any, dtype: DType=index): + match obj: + case int(): + assert dtype in integers + case float(): + assert dtype in floats + case _: + raise TypeError(f"ATen.const: Only integer or float objects can become constant! getting {obj}") + return ir.Const.new(obj, dtype) def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: return TensorSpec(item) # TODO: Display Shape, realized buffer, etc. def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.node}>" + @property + def dtype(self): return self.op.T.dtype + @staticmethod + def ensure_tensor(self, obj: Any, dtype: DType = index): + if isinstance(obj, ATen): + assert obj.dtype == dtype # todo: decent error msg + return obj + else: + return ATen.const(obj, dtype=dtype) @staticmethod def top(f: Callable[Any, ATen]): """ @@ -38,9 +57,8 @@ def polyhedral(self): ## movement ops mixin class ATenMovements(): @property - def shape(self): -> List[ATen]: + def shape(self) -> List[ATen]: pass - @ATen.top def reshape(self, shape, *args) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) @@ -49,7 +67,7 @@ def reshape(self, shape, *args) -> Self: if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") - ret = ATen(op=ir.View.reshape(self.op, new_shape)) # TODO: new_shape is ATenOp? + ret = ATen(op=ir.View.reshape(self.op, [ATen.ensure_tensor(s, dtype=index) for s in new_shape])) # TODO: new_shape is ATenOp? return self if ret.shape == self.shape else ret ## arithmetic mixin class ATenArith(): From a5c0b823d9d3290b9a1a2fe50b94919875cfb539 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 17:49:13 +0900 Subject: [PATCH 21/53] some clean ups --- caten/ir.py | 12 ++++++------ caten/tensor.py | 15 ++++++++------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index ae5c28ff..b9288918 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -8,7 +8,7 @@ @dataclass(frozen=True) class ATenAxis(): - shape: ATenOp + size: ATenOp stride: ATenOp offset: ATenOp incf: ATenOp @@ -18,17 +18,17 @@ def index(self, i: ATenOp): @dataclass(frozen=True) class ATenOpType(): - shape: tuple[ATenAxis] + axes: tuple[ATenAxis] dtype: DType offset: Union[ATenOp, None] = None is_ptr: bool = False # for vectorize def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) - total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.shape)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) + total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) if self.offset: total = Add([total, self.offset]) return total @property - def ndim(self): return len(self.shape) + def ndim(self): return len(self.axes) @staticmethod def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: def _const(val: int): return Const.new(val, index) @@ -38,7 +38,7 @@ def _mul(a, b): return Mul((a, b)) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( - shape=[ATenAxis(shape=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], + axes=[ATenAxis(size=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, ) @@ -180,7 +180,7 @@ class Const(ATenOp): value: Union[int, float, str] = 0.0 @staticmethod def new(value: Union[int, float, str], dtype: DType): - return Const(args=(), value=value, T=ATenOpType(shape=[], dtype=dtype)) + return Const(args=(), value=value, T=ATenOpType(axes=[], dtype=dtype)) @dataclass(frozen=True) class Allocate(ATenOp): diff --git a/caten/tensor.py b/caten/tensor.py index 39c26c3c..57a96fc4 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -38,7 +38,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.node}>" @property def dtype(self): return self.op.T.dtype @staticmethod - def ensure_tensor(self, obj: Any, dtype: DType = index): + def wrap_const(self, obj: Any, dtype: DType = index): + """ + Ensures obj is a constant of dtype + """ if isinstance(obj, ATen): assert obj.dtype == dtype # todo: decent error msg return obj @@ -51,14 +54,12 @@ def top(f: Callable[Any, ATen]): """ # TODO: Toplevel in helpers.py return f - def polyhedral(self): - pass - ## movement ops mixin class ATenMovements(): @property - def shape(self) -> List[ATen]: - pass + def shape(self) -> List[ATen]: return [x.shape for x in self.op.T.axes] + @property + def strides(self) -> List[ATen]: return [x.stride for x in self.op.T.axes] @ATen.top def reshape(self, shape, *args) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) @@ -67,7 +68,7 @@ def reshape(self, shape, *args) -> Self: if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") - ret = ATen(op=ir.View.reshape(self.op, [ATen.ensure_tensor(s, dtype=index) for s in new_shape])) # TODO: new_shape is ATenOp? + ret = ATen(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) # TODO: new_shape is ATenOp? return self if ret.shape == self.shape else ret ## arithmetic mixin class ATenArith(): From 7e60cc06c9272eb484842be0b9bcc3f1700d2bc2 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 17:55:48 +0900 Subject: [PATCH 22/53] just an idea --- caten/tensor.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/caten/tensor.py b/caten/tensor.py index 57a96fc4..266de010 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -20,7 +20,7 @@ def __repr__(self) -> str: return f"ATenSpec{self.shape}" class ATen: op: ATenOp # ATen is just a wrapper for ATenOp @classmethod - def from_shape(cls, shape: List[ATenOp], dtype: DType=default_float): return ir.Allocate.new(shape, dtype) + def from_shape(cls, shape: List[ATenOp], dtype: DType=default_float): return Tensor(op=ir.Allocate.new(shape, dtype)) @classmethod def const(cls, obj: Any, dtype: DType=index): match obj: @@ -31,10 +31,10 @@ def const(cls, obj: Any, dtype: DType=index): case _: raise TypeError(f"ATen.const: Only integer or float objects can become constant! getting {obj}") return ir.Const.new(obj, dtype) - def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) + def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return Tensor(op=op(*args, **kwargs)) def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: return TensorSpec(item) # TODO: Display Shape, realized buffer, etc. - def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.node}>" + def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.op}>" @property def dtype(self): return self.op.T.dtype @staticmethod @@ -68,7 +68,7 @@ def reshape(self, shape, *args) -> Self: if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") - ret = ATen(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) # TODO: new_shape is ATenOp? + ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) # TODO: new_shape is ATenOp? return self if ret.shape == self.shape else ret ## arithmetic mixin class ATenArith(): @@ -94,6 +94,9 @@ class Facet(): pass ## abstraction over backends class ATenBase(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): + def __init__(self, *args, op=None): + self.op = op + ## == AbstractionLayer @staticmethod def register(device_id: str, cls: ClassVar): From 71fc714da4ca958e0042eb98d4e0ae5708cab57b Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 18:01:22 +0900 Subject: [PATCH 23/53] feat: reshape --- caten/ir.py | 2 +- caten/tensor.py | 6 +++--- test/test_kernel.py | 10 ++++++---- test/test_movements.py | 10 ++++++++++ 4 files changed, 20 insertions(+), 8 deletions(-) create mode 100644 test/test_movements.py diff --git a/caten/ir.py b/caten/ir.py index b9288918..e4d2307b 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -198,7 +198,7 @@ class View(ATenOp): """ @staticmethod def reshape(tensor: ATenOp, shape: List[ATenOp]): - pass + return View((tensor,), T=ATenOpType.from_shape(shape, tensor.T.dtype)) @staticmethod def new(tensor: ATenOp, view: ATenOpType): diff --git a/caten/tensor.py b/caten/tensor.py index 266de010..542eb76c 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -38,7 +38,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.op}>" @property def dtype(self): return self.op.T.dtype @staticmethod - def wrap_const(self, obj: Any, dtype: DType = index): + def wrap_const(obj: Any, dtype: DType = index): """ Ensures obj is a constant of dtype """ @@ -57,7 +57,7 @@ def top(f: Callable[Any, ATen]): ## movement ops mixin class ATenMovements(): @property - def shape(self) -> List[ATen]: return [x.shape for x in self.op.T.axes] + def shape(self) -> List[ATen]: return [x.size for x in self.op.T.axes] @property def strides(self) -> List[ATen]: return [x.stride for x in self.op.T.axes] @ATen.top @@ -80,7 +80,7 @@ class ATenMath(): @ATen.top def sin(self: ATen): return self.forward(ir.Sin, self) @ATen.top - def cos(self: ATen): return self.forward(ir.Sin, self + Tensor(0.0)) + def cos(self: ATen): return self.forward(ir.Sin, self + Tensor.const(0.0, dtype=self.dtype)) ## nn ops mixin class ATenNN(): pass diff --git a/test/test_kernel.py b/test/test_kernel.py index be4fff84..a6e8fada 100644 --- a/test/test_kernel.py +++ b/test/test_kernel.py @@ -1,20 +1,22 @@ import caten as C def test_tensor(): - print(C.Tensor.from_shape([10, 10], dtype=C.float32)) + tensor = C.Tensor.from_shape([10, 10], dtype=C.float32) + print(tensor) + print(tensor.op.T) + print(tensor.reshape([2, 5, 10])) -def test_matmul_kernel(): +def atest_matmul_kernel(): @C.kernel() def matmul(A: C.Tensor[N, K], B: C.Tensor[K, M]): Out = C.Tensor(N, M, dtype=A.dtype) with C.range(N) as i: with C.range(M) as j: - acc = C.Const(0.0) + acc = C.LocalVar(0.0) with C.range(K) as k: acc += + A[i, k] * B[k, j] Out[i, j] = C.tanh(acc) return Out - # TODO: # 1. VMAP # 2. Symbolic diff --git a/test/test_movements.py b/test/test_movements.py new file mode 100644 index 00000000..36ab322c --- /dev/null +++ b/test/test_movements.py @@ -0,0 +1,10 @@ +import pytest + +def test_reshape(): + pass + +def test_reshape_const(): + pass + +def test_reshape_dynamic(): + pass From 704b6becf2d64b17c28723e24db30029df174fed Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 18:09:26 +0900 Subject: [PATCH 24/53] feat: reshape --- caten/ir.py | 11 +++++++++++ caten/tensor.py | 32 +++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index e4d2307b..8c7b2d23 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -41,6 +41,13 @@ def _mul(a, b): axes=[ATenAxis(size=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, ) + def permute(self, order: List[int]): + return ATenOpType( + axes=[self.axes[i] for i in order], + dtype=self.dtype, + offset=self.offset, + is_ptr=self.is_ptr + ) @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): @@ -199,6 +206,10 @@ class View(ATenOp): @staticmethod def reshape(tensor: ATenOp, shape: List[ATenOp]): return View((tensor,), T=ATenOpType.from_shape(shape, tensor.T.dtype)) + + @staticmethod + def permute(tensor: ATenOp, order: List[int]): + return View((tensor,), T=tensor.T.permute(order)) @staticmethod def new(tensor: ATenOp, view: ATenOpType): diff --git a/caten/tensor.py b/caten/tensor.py index 542eb76c..45b27292 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -24,17 +24,15 @@ def from_shape(cls, shape: List[ATenOp], dtype: DType=default_float): return Ten @classmethod def const(cls, obj: Any, dtype: DType=index): match obj: - case int(): - assert dtype in integers - case float(): - assert dtype in floats - case _: - raise TypeError(f"ATen.const: Only integer or float objects can become constant! getting {obj}") + case int(): assert dtype in integers + case float(): assert dtype in floats + case _: raise TypeError(f"ATen.const: Only integer or float objects can become constant! getting {obj}") return ir.Const.new(obj, dtype) def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return Tensor(op=op(*args, **kwargs)) def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: return TensorSpec(item) - # TODO: Display Shape, realized buffer, etc. - def __repr__(self) -> str: return f"{self.__class__.__name__}<{self.op}>" + def __repr__(self) -> str: + # TODO: Display Shape, realized buffer, etc. + return f"{self.__class__.__name__}<{self.op}>" @property def dtype(self): return self.op.T.dtype @staticmethod @@ -60,6 +58,13 @@ class ATenMovements(): def shape(self) -> List[ATen]: return [x.size for x in self.op.T.axes] @property def strides(self) -> List[ATen]: return [x.stride for x in self.op.T.axes] + @property + def ndim(self) -> int: return len(self.shape) + def _resolve_dim(self, dim: int, *, extra: bool = False) -> int: + total = self.ndim + int(extra) + if not -max(1, total) <= dim <= max(1, total) - 1: + raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}") + return dim + total if dim < 0 else dim @ATen.top def reshape(self, shape, *args) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) @@ -68,8 +73,17 @@ def reshape(self, shape, *args) -> Self: if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") - ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) # TODO: new_shape is ATenOp? + ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) return self if ret.shape == self.shape else ret + @ATen.top + def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self: + raise NotImplementedError("shrink todo") + @ATen.top + def permute(self, order, *args) -> Self: + order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) + if sorted(order_arg) != list(range(self.ndim)): + raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") + return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self ## arithmetic mixin class ATenArith(): @ATen.top From 0b526ced1a4de803f0a15adea5e37c83cfa2b8d9 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 18:29:09 +0900 Subject: [PATCH 25/53] feat: expand --- caten/helpers.py | 5 +++++ caten/ir.py | 32 +++++++++++++++++++++----------- caten/tensor.py | 21 ++++++++++++++++++++- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/caten/helpers.py b/caten/helpers.py index bb7bed02..1ad5a204 100644 --- a/caten/helpers.py +++ b/caten/helpers.py @@ -11,3 +11,8 @@ def argfix(*x): if len(x) != 1: raise ValueError(f"bad arg {x}") return tuple(x[0]) return x + +def align_left(*shapes): + # unsqueeze left to make every shape same length + max_dim = max(len(shape) for shape in shapes) + return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes) diff --git a/caten/ir.py b/caten/ir.py index 8c7b2d23..65f80ce2 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -41,13 +41,6 @@ def _mul(a, b): axes=[ATenAxis(size=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], dtype=dtype, ) - def permute(self, order: List[int]): - return ATenOpType( - axes=[self.axes[i] for i in order], - dtype=self.dtype, - offset=self.offset, - is_ptr=self.is_ptr - ) @dataclass(frozen=True) class ATenOp(metaclass=ABCMeta): @@ -203,17 +196,34 @@ class View(ATenOp): """ View(X, T=T_New) """ + # This is the definition of view @staticmethod def reshape(tensor: ATenOp, shape: List[ATenOp]): return View((tensor,), T=ATenOpType.from_shape(shape, tensor.T.dtype)) @staticmethod def permute(tensor: ATenOp, order: List[int]): - return View((tensor,), T=tensor.T.permute(order)) - + return View((tensor,), T=ATenOpType( + axes=[tensor.T.axes[i] for i in order], + dtype=tensor.T.dtype, + offset=tensor.T.offset, + is_ptr=tensor.T.is_ptr + )) + @staticmethod - def new(tensor: ATenOp, view: ATenOpType): - pass + def expand(tensor: ATenOp, shape: List[Union[int, ATenOp]]): + def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: + if old_axis.size == new_size: return old_axis + else: + assert old_axis == -1 + return ATenAxis(size=new_size, stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) + return View((tensor,), T=ATenOpType( + axes=[_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)] + dtype=tensor.T.dtype, + offset=tensor.T.offset, + is_ptr=tensor.T.is_ptr + ) + ## == JIT ===================================================================== @dataclass(frozen=True) class Reduce(ATenOp): diff --git a/caten/tensor.py b/caten/tensor.py index 45b27292..7b849dd8 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -4,7 +4,7 @@ import os import caten.ir as ir from .dtype import default_float, index, floats, integers -from caten.helpers import argfix, prod +from caten.helpers import argfix, prod, align_left ## Backend Abstraction DEVICE_TO_TENSOR = {} def get_backend(): return os.environ.get("BACKEND", "CPU") @@ -65,6 +65,20 @@ def _resolve_dim(self, dim: int, *, extra: bool = False) -> int: if not -max(1, total) <= dim <= max(1, total) - 1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}") return dim + total if dim < 0 else dim + # ref: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/mixin/movement.py#L58 + def _broadcast_to(self, new_shape: List[ATen]) -> Self: + """ + Implements Numpy-Semantic Broadcasting operation + """ + if self.shape == new_shape: return self + if self.ndim > len(new_shape): + raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}") + shape, _ = align_left(self.shape, new_shape) + if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)): + raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") + reshaped = self.reshape(shape) + ret = Tensor(op=ir.View.expand(self.op, new_shape)) + return reshaped if ret.shape == reshaped.shape else ret @ATen.top def reshape(self, shape, *args) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) @@ -84,11 +98,16 @@ def permute(self, order, *args) -> Self: if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self + @ATen.top + def expand(self, shape, *args) -> Self: + new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))))) + return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) ## arithmetic mixin class ATenArith(): @ATen.top def add(self, other): pass + # TODO: self == 1 is evalued to true if self is const ## math mixin class ATenMath(): @ATen.top From 62de2795c97f2b750480944089f4be40feb848b8 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 18:29:39 +0900 Subject: [PATCH 26/53] feat: expand --- caten/ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 65f80ce2..7da5900b 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -218,11 +218,11 @@ def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: assert old_axis == -1 return ATenAxis(size=new_size, stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) return View((tensor,), T=ATenOpType( - axes=[_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)] + axes=[_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)], dtype=tensor.T.dtype, offset=tensor.T.offset, is_ptr=tensor.T.is_ptr - ) + )) ## == JIT ===================================================================== @dataclass(frozen=True) From f679b5025b7f68b3d3d90b7aabdda0fc1ecfcccc Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 19:17:42 +0900 Subject: [PATCH 27/53] Cache, Type Inference, Verification --- caten/ir.py | 123 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 74 insertions(+), 49 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 7da5900b..dffc87be 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -2,10 +2,30 @@ from abc import ABCMeta, abstractmethod from typing import List, Dict, Any, Union -import itertools +import itertools, weakref, dataclasses from dataclasses import dataclass from .dtype import DType, index +class ATenOpMetaclass(type): + cache: Dict[tuple, weakref.ReferenceType[ATenOp]] = {} + @staticmethod + def _freeze(x: Any) -> Any: + if isinstance(x, ATenOp): return x + if dataclasses.is_dataclass(x): + return (type(x),) + tuple((f.name, ATenOpMetaclass._freeze(getattr(x, f.name))) for f in dataclasses.fields(x)) + if isinstance(x, (list, tuple)): + return tuple(ATenOpMetaclass._freeze(i) for i in x) + if isinstance(x, dict): + return tuple(sorted((k, ATenOpMetaclass._freeze(v)) for k, v in x.items())) + return x + def __call__(cls, args: tuple[ATenOp, ...], T: "ATenOpType | None" = None, **kwargs): + T = cls.verify(args, T, **kwargs) # type inference + wret = ATenOpMetaclass.cache.get(key:=(cls, tuple(args), ATenOpMetaclass._freeze(T), ATenOpMetaclass._freeze(kwargs)), None) + if wret is not None and (ret:=wret()) is not None: + return ret + ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(args, T=T, **kwargs)) + return created + @dataclass(frozen=True) class ATenAxis(): size: ATenOp @@ -38,29 +58,24 @@ def _mul(a, b): return Mul((a, b)) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( - axes=[ATenAxis(size=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)], + axes=tuple([ATenAxis(size=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)]), dtype=dtype, ) @dataclass(frozen=True) -class ATenOp(metaclass=ABCMeta): +class ATenOp(metaclass=ATenOpMetaclass): args: List[ATenOp] - T: Union[ATenOpType, None] = None - # TODO: Cached? - # def __init__(self, ...) + T: Union[ATenOpType, None] = None @property def predecessors(self): # TODO: # - Tに含まれるOpsをReadに含める # - RangifyしたらSymbolicのDepsは消える pass + @classmethod -# @abstractmethod - def from_astexpr(cls): - pass -# @abstractmethod - def verify(self): - pass + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + raise NotImplementedError("Not implemented") def coalese(self): # Simplify myself @@ -71,45 +86,63 @@ def deepwalk(self): ## == Tensor Graph ============================================================ class UnaryOps(): - def verify(self): verify_tensor_op(self, 1) + # ops whose first argument is returned dtype + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + assert len(args) == 1 + return args[0].T class BinaryOps(): - def verify(self): verify_tensor_op(self, 2) + # ops whose first argument is returned dtype + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + assert len(args) == 2 + return args[0].T class TernaryOps(): - def verify(self): verify_tensor_op(self, 3) + # ops whose first argument is returned dtype + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + assert len(args) == 3 + return args[0].T +class ViewOps(): + # ops whose return dtypes are explicitly provided via T option + @classmethod + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + assert T is not None, f"Cannot create {cls.__name__} without providing T" + return T ### UnaryOps @dataclass(frozen=True) -class Neg(ATenOp, UnaryOps): +class Neg(UnaryOps, ATenOp): """ OUT = -X """ pass @dataclass(frozen=True) -class Recip(ATenOp, UnaryOps): +class Recip(UnaryOps, ATenOp): pass @dataclass(frozen=True) -class Sin(ATenOp, UnaryOps): +class Sin(UnaryOps, ATenOp): pass @dataclass(frozen=True) -class Exp2(ATenOp, UnaryOps): +class Exp2(UnaryOps, ATenOp): pass @dataclass(frozen=True) -class Log2(ATenOp, UnaryOps): +class Log2(UnaryOps, ATenOp): pass @dataclass(frozen=True) -class Sqrt(ATenOp, UnaryOps): +class Sqrt(UnaryOps, ATenOp): pass @dataclass(frozen=True) -class Bitcast(ATenOp, UnaryOps): +class Bitcast(UnaryOps, ATenOp): pass @dataclass(frozen=True) -class Not(ATenOp, UnaryOps): +class Not(UnaryOps, ATenOp): """ Logical not if the X is a boolean otherwise lognot ~x @@ -117,73 +150,66 @@ class Not(ATenOp, UnaryOps): pass ### BinaryOps @dataclass(frozen=True) -class Add(ATenOp, BinaryOps): +class Add(BinaryOps, ATenOp): """ OUT = Add(X, Y) """ - @classmethod - def from_ast_expr(cls): - pass @dataclass(frozen=True) -class Mul(ATenOp, BinaryOps): +class Mul(BinaryOps, ATenOp): """ OUT = Mul(X, Y) """ - @classmethod - def from_ast_expr(cls): - pass - @dataclass(frozen=True) -class IDiv(ATenOp, BinaryOps): +class IDiv(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class And(ATenOp, BinaryOps): +class And(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class Or(ATenOp, BinaryOps): +class Or(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class And(ATenOp, BinaryOps): +class And(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class Xor(ATenOp, BinaryOps): +class Xor(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class Max(ATenOp, BinaryOps): +class Max(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class Mod(ATenOp, BinaryOps): +class Mod(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class Neq(ATenOp, BinaryOps): +class Neq(BinaryOps, ATenOp): pass @dataclass(frozen=True) -class Lt(ATenOp, BinaryOps): +class Lt(BinaryOps, ATenOp): pass ### TernaryOps @dataclass(frozen=True) -class Where(ATenOp, TernaryOps): +class Where(TernaryOps, ATenOp): pass ### Allocation @dataclass(frozen=True) -class Const(ATenOp): +class Const(ViewOps, ATenOp): value: Union[int, float, str] = 0.0 @staticmethod def new(value: Union[int, float, str], dtype: DType): - return Const(args=(), value=value, T=ATenOpType(axes=[], dtype=dtype)) + return Const(args=(), value=value, T=ATenOpType(axes=(), dtype=dtype)) @dataclass(frozen=True) -class Allocate(ATenOp): +class Allocate(ViewOps, ATenOp): """ Allocate(S1, S2, S3, ...) """ @@ -192,7 +218,7 @@ def new(shape: List[Any], dtype: DType): return Allocate((), T=ATenOpType.from_shape(shape, dtype)) @dataclass(frozen=True) -class View(ATenOp): +class View(ViewOps, ATenOp): """ View(X, T=T_New) """ @@ -204,7 +230,7 @@ def reshape(tensor: ATenOp, shape: List[ATenOp]): @staticmethod def permute(tensor: ATenOp, order: List[int]): return View((tensor,), T=ATenOpType( - axes=[tensor.T.axes[i] for i in order], + axes=tuple([tensor.T.axes[i] for i in order]), dtype=tensor.T.dtype, offset=tensor.T.offset, is_ptr=tensor.T.is_ptr @@ -218,12 +244,11 @@ def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: assert old_axis == -1 return ATenAxis(size=new_size, stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) return View((tensor,), T=ATenOpType( - axes=[_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)], + axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)]), dtype=tensor.T.dtype, offset=tensor.T.offset, is_ptr=tensor.T.is_ptr )) - ## == JIT ===================================================================== @dataclass(frozen=True) class Reduce(ATenOp): From bc0f4a1c51ffe405804b4c3bae7f72e951533087 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 19:19:39 +0900 Subject: [PATCH 28/53] Cache, Type Inference, Verification --- caten/ir.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index dffc87be..22d63875 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -19,10 +19,9 @@ def _freeze(x: Any) -> Any: return tuple(sorted((k, ATenOpMetaclass._freeze(v)) for k, v in x.items())) return x def __call__(cls, args: tuple[ATenOp, ...], T: "ATenOpType | None" = None, **kwargs): - T = cls.verify(args, T, **kwargs) # type inference + T = cls.verify(args, T, **kwargs) # run type inference+verification wret = ATenOpMetaclass.cache.get(key:=(cls, tuple(args), ATenOpMetaclass._freeze(T), ATenOpMetaclass._freeze(kwargs)), None) - if wret is not None and (ret:=wret()) is not None: - return ret + if wret is not None and (ret:=wret()) is not None: return ret ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(args, T=T, **kwargs)) return created @@ -41,7 +40,7 @@ class ATenOpType(): axes: tuple[ATenAxis] dtype: DType offset: Union[ATenOp, None] = None - is_ptr: bool = False # for vectorize + is_ptr: bool = False # TODO: for vectorize? def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) @@ -89,19 +88,19 @@ class UnaryOps(): # ops whose first argument is returned dtype @classmethod def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: - assert len(args) == 1 + assert len(args) == 1, f"UnaryOp {cls.__name__} takes one argument, getting {args}" return args[0].T class BinaryOps(): # ops whose first argument is returned dtype @classmethod def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: - assert len(args) == 2 + assert len(args) == 2, f"BinaryOp {cls.__name__} takes two argument, getting {args}" return args[0].T class TernaryOps(): # ops whose first argument is returned dtype @classmethod def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: - assert len(args) == 3 + assert len(args) == 3, f"TernaryOp {cls.__name__} takes three argument, getting {args}" return args[0].T class ViewOps(): # ops whose return dtypes are explicitly provided via T option From 74705513de9b3c16f33cb3eed24352ae854cd122 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 19:20:06 +0900 Subject: [PATCH 29/53] Cache, Type Inference, Verification --- caten/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caten/ir.py b/caten/ir.py index 22d63875..ec4169e7 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -64,7 +64,7 @@ def _mul(a, b): @dataclass(frozen=True) class ATenOp(metaclass=ATenOpMetaclass): args: List[ATenOp] - T: Union[ATenOpType, None] = None + T: Union[ATenOpType, None] = None # this should be provided via T=... option, or inferred via verify method. @property def predecessors(self): # TODO: From d66007cb2361b8e0e00fc6568edddbcac53d7775 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 19:25:16 +0900 Subject: [PATCH 30/53] Cache, Type Inference, Verification --- AGENTS.md | 25 +------------------------ caten/ir.py | 7 ++++++- 2 files changed, 7 insertions(+), 25 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index c0921992..aee36af4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -37,27 +37,4 @@ ## Polyhedral DSL Guidelines - Prefer using Mixin operator overloads (e.g., `A | B` instead of `A.union(B)`) for cleaner code in user scripts and DSL implementations. - -## 作業計画と進捗 (2025-11-16) -直近のギャップ集計: `docs/ISL_missing_apis.md`(2025-11-16 再生成、欠落API 2047件)。map 残 2 件(tuple_name系シンボル未提供のみ、libisl非存在)。 -優先順とステータス(✅完了 / 🚧着手中 / ⏳未着手) -- Identifier / Id: 🚧(基本APIは揃うが欠落検証 継続) -- Space / LocalSpace: 🚧(dim/tuple系以外の抜け有り) -- Constraint / Equality-Constraint / Inequality-Constraint: 🚧 -- BasicSet / Set: 🚧(missing計: set 105, basic_set 63) -- UnionSet: 🚧(missing計: union_set 52) -- BasicMap / Map: 🚧(missing計: basic_map 85, map 190) -- UnionMap: 🚧(missing計: union_map 112) -- Aff / PwAff / MultiAff / PwMultiAff: 🚧(missing計: aff 73, pw_aff 96, multi_aff 90, pw_multi_aff 89) -- MultiVal: 🚧(missing計: multi_val 37, val 66) -- MultiUnionPwAff / UnionPwAff / UnionPwMultiAff / MultiUnionPwAff: 🚧(missing計: multi_union_pw_aff 75 ほか) -- ScheduleConstraint / Schedule / ScheduleNode: ✅(schedule_node 0) -- UnionAccessInfo / UnionFlow: ⏳ -- ASTExpr / ASTNode / ASTBuild: 🚧(Expr系クラス不足・missing計: ast_expr 0, ast_node 0) -- Mat: ✅(要素参照系API実装済・missing計: mat 0) -- その他: misc 71, options 29 など多数。 - -次に着手する対象: -1) ScheduleNode / ASTExpr / Mat のクラス追加・アクセサ補完 -2) UnionAccessInfo / UnionFlow ラッパ実装 -3) map / set 系を皮切りに `docs/ISL_missing_apis.md` に基づく欠落API埋め +- Do not write shit code, be respectful to existing codes. diff --git a/caten/ir.py b/caten/ir.py index ec4169e7..bfbf8e2b 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -82,7 +82,9 @@ def coalese(self): def deepwalk(self): pass - + + def viz(self): + pass ## == Tensor Graph ============================================================ class UnaryOps(): # ops whose first argument is returned dtype @@ -153,12 +155,15 @@ class Add(BinaryOps, ATenOp): """ OUT = Add(X, Y) """ + pass @dataclass(frozen=True) class Mul(BinaryOps, ATenOp): """ OUT = Mul(X, Y) """ + pass + @dataclass(frozen=True) class IDiv(BinaryOps, ATenOp): pass From 8d681928176476724bcfa55140ad051ef07a3870 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 19:47:41 +0900 Subject: [PATCH 31/53] Feat: Simplifier --- caten/ir.py | 7 ++-- caten/simplifier.py | 84 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 caten/simplifier.py diff --git a/caten/ir.py b/caten/ir.py index bfbf8e2b..e0c64881 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -66,11 +66,8 @@ class ATenOp(metaclass=ATenOpMetaclass): args: List[ATenOp] T: Union[ATenOpType, None] = None # this should be provided via T=... option, or inferred via verify method. @property - def predecessors(self): - # TODO: - # - Tに含まれるOpsをReadに含める - # - RangifyしたらSymbolicのDepsは消える - pass + def predecessors(self) -> tuple[ATenOp, ...]: + return tuple(args) + tuple(*[tuple(axis.size, axis.stride, axis.offset, axis.incf) for axis in self.T.axes]) + () if self.offset is None else tuple([self.offset]) @classmethod def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: diff --git a/caten/simplifier.py b/caten/simplifier.py new file mode 100644 index 00000000..47e8b9f9 --- /dev/null +++ b/caten/simplifier.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import inspect +from dataclasses import is_dataclass, replace +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +from .ir import ATenOp, Const, Add, Mul + +OpType = Type[ATenOp] + +class Pat: + def __init__( + self, + op: Union[OpType, Tuple[OpType, ...], None] = None, + name: Optional[str] = None, + src: Optional[Tuple["Pat", ...]] = None, + meta: Optional[Dict[str, Callable[[ATenOp], Any]]] = None, + ): + self.op = op if (op is None or isinstance(op, tuple)) else (op,) + self.name, self.src = name, src + self.meta = meta or {} + def match(self, n: ATenOp, ctx: Dict[str, Any]) -> bool: + if self.op is not None and not isinstance(n, self.op): return False + if self.src is not None: + ks = tuple(n.args) + if len(ks) != len(self.src): return False + if not all(p.match(k, ctx) for p, k in zip(self.src, ks, strict=True)): return False + for var, ex in self.meta.items(): + v = ex(n) + if var in ctx and ctx[var] != v: return False + ctx[var] = v + if self.name: + if self.name in ctx and ctx[self.name] != n: return False + ctx[self.name] = n + return True + @staticmethod + def var(name: str) -> "Pat": + return Pat(name=name) + +class Simplifier: + def __init__(self, patterns: List[Tuple[Pat, Callable[..., Any]]]): + self.patterns = patterns + + def rewrite(self, n: ATenOp, ctx_obj: Any = None) -> Optional[ATenOp]: + for pat, fn in self.patterns: + m: Dict[str, Any] = {} + if not pat.match(n, m): continue + + sig = inspect.signature(fn) + argv = [(ctx_obj if p == "ctx" else m.get(p)) for p in sig.parameters] + out = fn(*argv) + + if out is None: continue + if isinstance(out, ATenOp): return out + raise TypeError(f"rewrite returned unsupported type: {type(out)}") + return None + + def _walk_once(self, root: ATenOp, ctx_obj: Any = None) -> Tuple[ATenOp, bool]: + changed = False + memo: Dict[int, ATenOp] = {} + def go(n: ATenOp) -> ATenOp: + nonlocal changed + if (r := memo.get(id(n))) is not None: return r + ks = tuple(n.args) + if ks: + ks2 = tuple(go(k) for k in ks) + if ks2 != ks and is_dataclass(n): + n = replace(n, args=type(n.args)(ks2)) + changed = True + while True: + r2 = self.rewrite(n, ctx_obj) + if r2 is None or r2 == n: break + n = r2 + changed = True + memo[id(n)] = n + return n + return go(root), changed + + def simplify(self, root: ATenOp, ctx_obj: Any = None, max_iters: int = 30000) -> ATenOp: + cur = root + for _ in range(max_iters): + cur, ch = self._walk_once(cur, ctx_obj) + if not ch: break + return cur From 5820b6740a38238bb8e9877f1ed775af55745a39 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 20:13:19 +0900 Subject: [PATCH 32/53] Feat: Simplifier --- caten/__init__.py | 2 ++ caten/ir.py | 10 +++++----- caten/runtime/cpu.py | 6 ++++++ caten/simplifier.py | 37 +++++++++++++++++++++++++++++++++---- caten/tensor.py | 6 +++++- 5 files changed, 51 insertions(+), 10 deletions(-) diff --git a/caten/__init__.py b/caten/__init__.py index a98f46ef..ae1d2c88 100644 --- a/caten/__init__.py +++ b/caten/__init__.py @@ -2,10 +2,12 @@ from .dtype import * from .tensor import * from .runtime import cpu +from .simplifier import * __all__ = [ "dtype", "helpers", "ir", "tensor" + "simplifier" ] diff --git a/caten/ir.py b/caten/ir.py index e0c64881..452640f3 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -21,9 +21,9 @@ def _freeze(x: Any) -> Any: def __call__(cls, args: tuple[ATenOp, ...], T: "ATenOpType | None" = None, **kwargs): T = cls.verify(args, T, **kwargs) # run type inference+verification wret = ATenOpMetaclass.cache.get(key:=(cls, tuple(args), ATenOpMetaclass._freeze(T), ATenOpMetaclass._freeze(kwargs)), None) - if wret is not None and (ret:=wret()) is not None: return ret + if wret is not None and (ret:=wret()) is not None: return ret.simplify() ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(args, T=T, **kwargs)) - return created + return created.simplify() @dataclass(frozen=True) class ATenAxis(): @@ -73,9 +73,9 @@ def predecessors(self) -> tuple[ATenOp, ...]: def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: raise NotImplementedError("Not implemented") - def coalese(self): - # Simplify myself - pass + def simplify(self): + from caten.simplifier import simplifier + return simplifier.simplify(self) def deepwalk(self): pass diff --git a/caten/runtime/cpu.py b/caten/runtime/cpu.py index 6c06816c..76309c94 100644 --- a/caten/runtime/cpu.py +++ b/caten/runtime/cpu.py @@ -7,7 +7,13 @@ def allocate(self): def free(self): pass + #@staticmethod def compile(self): pass + @staticmethod + def render(op): + def _render(node): + pass + C.ATenBase.register("CPU", CPUTensor) diff --git a/caten/simplifier.py b/caten/simplifier.py index 47e8b9f9..d173d31e 100644 --- a/caten/simplifier.py +++ b/caten/simplifier.py @@ -1,11 +1,13 @@ from __future__ import annotations -import inspect +import inspect, operator +import math + from dataclasses import is_dataclass, replace from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from .ir import ATenOp, Const, Add, Mul - +from .ir import ATenOp +import caten.ir as ir OpType = Type[ATenOp] class Pat: @@ -40,7 +42,8 @@ def var(name: str) -> "Pat": class Simplifier: def __init__(self, patterns: List[Tuple[Pat, Callable[..., Any]]]): self.patterns = patterns - + def __add__(self, other: Simplifier) -> Simplifier: + return Simplifier(self.patterns+other.patterns) def rewrite(self, n: ATenOp, ctx_obj: Any = None) -> Optional[ATenOp]: for pat, fn in self.patterns: m: Dict[str, Any] = {} @@ -82,3 +85,29 @@ def simplify(self, root: ATenOp, ctx_obj: Any = None, max_iters: int = 30000) -> cur, ch = self._walk_once(cur, ctx_obj) if not ch: break return cur + +# Guard Methods +def Guard(obj): pass +def _is_num(x: Any) -> bool: + return isinstance(x, (int, float)) and not isinstance(x, bool) + +constant_folder = Simplifier( + # UnaryOps + [( + Pat(ops[0], src=(Pat(ir.Const, name="x"))), + lambda x: ir.Const.new(ops[1](x.value), x.T.dtype) + if _is_num(x.value) + else None,) + for ops in [(ir.Sin, math.sin)]] + + # BinaryOps + [( + Pat(ops[0], src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), + lambda a, b: ir.Const.new(ops[1](a.value, b.value), a.T.dtype) + if (a.T.dtype == b.T.dtype and _is_num(a.value) and _is_num(b.value)) + else None,) + for ops in [(ir.Add, operator.add), (ir.Mul, operator.mul)] + ] + # Ternary Ops? +) + +simplifier = constant_folder diff --git a/caten/tensor.py b/caten/tensor.py index 7b849dd8..73c27ca4 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -147,6 +147,11 @@ def free(self): def compile(self): pass + @staticmethod + @abstractmethod + def render(op: ATenOp): + pass + class Tensor(ATenBase): def __new__(cls, *args, **kwargs): impl = DEVICE_TO_TENSOR.get(get_backend()) @@ -157,7 +162,6 @@ def kernel(get_kernel: bool = False) -> Callable: def decorator(func: Callable) -> Callable: pass return decorator - # how to generate polyhedral model from tensor ops? # rangeify -> range/when ==> polyhedral model # with C.range(10, 10): From fa3ae24e4d68340c68b33d4d7b23cbbeb4d2edb1 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 20:37:18 +0900 Subject: [PATCH 33/53] Feat: Simplifier --- caten/ir.py | 21 ++++++++++++++------- caten/tensor.py | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 452640f3..640c456b 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -35,6 +35,10 @@ def index(self, i: ATenOp): assert i.T.dtype == index, "ATenAxis.index: range index should be type of index." return Mul(self.stride, Add(Mul(i, self.incf), self.offset)) +def _const(val: int): + if isinstance(val, Const): return val + else: return Const.new(val, index) + @dataclass(frozen=True) class ATenOpType(): axes: tuple[ATenAxis] @@ -50,14 +54,10 @@ def index(self, indices: List[ATenOp]): def ndim(self): return len(self.axes) @staticmethod def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: - def _const(val: int): return Const.new(val, index) - def _mul(a, b): - if not isinstance(a, Const): a = _const(a) - if not isinstance(b, Const): b = _const(b) - return Mul((a, b)) + def _mul(a, b): return Mul((_const(a), _const(b))) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( - axes=tuple([ATenAxis(size=size, stride=stride, offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)]), + axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)]), dtype=dtype, ) @@ -82,6 +82,13 @@ def deepwalk(self): def viz(self): pass + + # Mixin for computing shapes (required by reshape, etc) + def __add__(self, other: Any): return Add((self, _const(other))) + def __radd__(self, other: Any): return Add((_const(other), self)) + def __mul__(self, other: Any): return Mul((self, _const(other))) + def __rmul__(self, other: Any): return Mul((_const(other), self)) + ## == Tensor Graph ============================================================ class UnaryOps(): # ops whose first argument is returned dtype @@ -243,7 +250,7 @@ def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: if old_axis.size == new_size: return old_axis else: assert old_axis == -1 - return ATenAxis(size=new_size, stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) + return ATenAxis(size=_const(new_size), stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) return View((tensor,), T=ATenOpType( axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)]), dtype=tensor.T.dtype, diff --git a/caten/tensor.py b/caten/tensor.py index 73c27ca4..e9b9867c 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -102,12 +102,41 @@ def permute(self, order, *args) -> Self: def expand(self, shape, *args) -> Self: new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))))) return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) +def smax(a, b): + # [TODO] Const fold and == 1 is required + print("TODO: SMAX") + return max(a, b) ## arithmetic mixin class ATenArith(): + def _broadcasted(self, y:Tensor|int|float, reverse:bool=False) -> tuple[Tensor, Tensor]: + x: ATen = self + if not isinstance(y, Tensor): + y = Tensor.const(y, dtype=x.dtype) + if x.dtype != y.dtype: + raise TypeError("Cannot add x and y (dtypes mismatch, todo)") + if reverse: x, y = y, x + # compute the output shape + def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: + return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) + out_shape = _broadcast_shape(x.shape, y.shape) + return x._broadcast_to(out_shape), y._broadcast_to(out_shape) + + # TODO: + # - reduce option + # - ir.Add.new (or binop) can have reduce option + @ATen.top + def add(self, other, reverse:bool=False): return self.forward(ir.Add, tuple(self._broadcasted(self, other, reverse=reverse))) @ATen.top - def add(self, other): + def mul(self, other, reverse:bool=False): return self.forward(ir.Mul, tuple(self._broadcasted(self, other, reverse=reverse))) + def __eq__(self, other: Any): + print("A") pass - # TODO: self == 1 is evalued to true if self is const + def __add__(self, other: Any): return self.add(other) + def __radd__(self, other: Any): return self.add(other, reverse=True) + def __mul__(self, other: Any): return self.mul(other) + def __rmul__(self, other: Any): return self.mul(other, reverse=True) + + # TODO: self == 1 is evalued to true if self is const ## math mixin class ATenMath(): @ATen.top @@ -126,7 +155,7 @@ class Facet(): # TODO: with tensor.facet("CUDA") as tensor: ... pass ## abstraction over backends -class ATenBase(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta): +class ATenBase(ATen, ATenMovements, ATenArith,ATenMath, ATenNN, ATenLinalg, metaclass=ABCMeta): def __init__(self, *args, op=None): self.op = op From d43c14be5dc1ce9c62030a323061322f29e090d9 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:03:04 +0900 Subject: [PATCH 34/53] Reshape symbolic --- caten/ir.py | 37 +++++++++++++++++++++++++++++++------ caten/tensor.py | 20 ++++++++++++-------- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 640c456b..b3d47978 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -35,9 +35,9 @@ def index(self, i: ATenOp): assert i.T.dtype == index, "ATenAxis.index: range index should be type of index." return Mul(self.stride, Add(Mul(i, self.incf), self.offset)) -def _const(val: int): +def _const(val: int, dtype: DType=index): if isinstance(val, Const): return val - else: return Const.new(val, index) + else: return Const.new(val, dtype) @dataclass(frozen=True) class ATenOpType(): @@ -60,7 +60,7 @@ def _mul(a, b): return Mul((_const(a), _const(b))) axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)]), dtype=dtype, ) - + @dataclass(frozen=True) class ATenOp(metaclass=ATenOpMetaclass): args: List[ATenOp] @@ -83,12 +83,37 @@ def deepwalk(self): def viz(self): pass + @property + def item(self): + # Returns scalar value if self is constant folded + if isinstance(self, Const) and isinstance(getattr(self, "value"), (int, float)): + return self.value + else: return self # Mixin for computing shapes (required by reshape, etc) + # TODO: Use same semantic of broadcast as tensor def __add__(self, other: Any): return Add((self, _const(other))) def __radd__(self, other: Any): return Add((_const(other), self)) def __mul__(self, other: Any): return Mul((self, _const(other))) def __rmul__(self, other: Any): return Mul((_const(other), self)) - + @staticmethod + def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]): + """ + """ + if isinstance(a, (int, float)) and isinstance(b, (int, float)): return (a == b) + dtype = a.T.dtype if isinstance(a, ATenOp) else b.T.dtype # A or B is asserted to have a dtype + a, b = _const(a, dtype=dtype), _const(b, dtype=dtype) + # Note(hikettei): this comparison highly depends on whether they are constant folded. + # plus, cannot verify the equivalence of A*B and B*A + return a == b + @staticmethod + def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp]]): + """ + Compares the equivalence + """ + if not len(a) == len(b): return False + for ai, bi in zip(a, b): + if not ATenOp.eql(ai, bi): return False + return True ## == Tensor Graph ============================================================ class UnaryOps(): # ops whose first argument is returned dtype @@ -247,9 +272,9 @@ def permute(tensor: ATenOp, order: List[int]): @staticmethod def expand(tensor: ATenOp, shape: List[Union[int, ATenOp]]): def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: - if old_axis.size == new_size: return old_axis + if ATenOp.eql(old_axis.size, new_size): return old_axis else: - assert old_axis == -1 + assert ATenOp.eql(old_axis, 1), f"The axis to expand should be evaluated to 1, getting {old_axis}" return ATenAxis(size=_const(new_size), stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) return View((tensor,), T=ATenOpType( axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)]), diff --git a/caten/tensor.py b/caten/tensor.py index e9b9867c..fc68246f 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -70,25 +70,25 @@ def _broadcast_to(self, new_shape: List[ATen]) -> Self: """ Implements Numpy-Semantic Broadcasting operation """ - if self.shape == new_shape: return self + if ir.ATenOp.equals(self.shape, new_shape): return self if self.ndim > len(new_shape): - raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}") + raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape}") shape, _ = align_left(self.shape, new_shape) - if not all(s == ns or s == 1 for s, ns in zip(shape, new_shape)): + if not all(ir.ATenOp.eql(s, ns) or ir.ATenOp.eql(s, 1) for s, ns in zip(shape, new_shape)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") reshaped = self.reshape(shape) ret = Tensor(op=ir.View.expand(self.op, new_shape)) - return reshaped if ret.shape == reshaped.shape else ret + return reshaped if ir.ATenOp.equals(ret.shape, reshaped.shape) else ret @ATen.top def reshape(self, shape, *args) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") - if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) - if prod(self.shape) != prod(new_shape): + if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if ir.ATenOp.eql(s, -1) else s for s in new_shape]) + if not ir.ATenOp.eql(prod(self.shape), prod(new_shape)): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) - return self if ret.shape == self.shape else ret + return self if ir.ATenOp.equals(ret.shape, self.shape) else ret @ATen.top def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self: raise NotImplementedError("shrink todo") @@ -100,8 +100,9 @@ def permute(self, order, *args) -> Self: return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self @ATen.top def expand(self, shape, *args) -> Self: - new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))))) + new_shape = tuple(from_ if ir.ATenOp.eql(to, -1) or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))))) return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) + def smax(a, b): # [TODO] Const fold and == 1 is required print("TODO: SMAX") @@ -131,6 +132,9 @@ def mul(self, other, reverse:bool=False): return self.forward(ir.Mul, tuple(self def __eq__(self, other: Any): print("A") pass + def __neq__(self, other: Any): + print("B") + pass def __add__(self, other: Any): return self.add(other) def __radd__(self, other: Any): return self.add(other, reverse=True) def __mul__(self, other: Any): return self.mul(other) From 6437e6934712347507835299297bff2ccc010f5e Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:03:52 +0900 Subject: [PATCH 35/53] docs --- caten/ir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/caten/ir.py b/caten/ir.py index b3d47978..5b0f2449 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -95,9 +95,11 @@ def __add__(self, other: Any): return Add((self, _const(other))) def __radd__(self, other: Any): return Add((_const(other), self)) def __mul__(self, other: Any): return Mul((self, _const(other))) def __rmul__(self, other: Any): return Mul((_const(other), self)) + # note: do not try to overload __eq__ since it is need to compute hash @staticmethod def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]): """ + Compare two scalars (Python numbers or ATenOp scalars) for equality. """ if isinstance(a, (int, float)) and isinstance(b, (int, float)): return (a == b) dtype = a.T.dtype if isinstance(a, ATenOp) else b.T.dtype # A or B is asserted to have a dtype @@ -108,7 +110,7 @@ def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]): @staticmethod def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp]]): """ - Compares the equivalence + Compare two lists element-wise using `ATenOp.eql` """ if not len(a) == len(b): return False for ai, bi in zip(a, b): From f35442abc424a49b52c05df3d1573005bea2b10d Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:05:10 +0900 Subject: [PATCH 36/53] docs --- caten/ir.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 5b0f2449..161f4484 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -147,27 +147,36 @@ class Neg(UnaryOps, ATenOp): """ OUT = -X """ - pass @dataclass(frozen=True) class Recip(UnaryOps, ATenOp): - pass + """ + OUT = 1/X + """ @dataclass(frozen=True) class Sin(UnaryOps, ATenOp): - pass + """ + OUT = sin(X) + """ @dataclass(frozen=True) class Exp2(UnaryOps, ATenOp): - pass + """ + OUT = exp2(X) + """ @dataclass(frozen=True) class Log2(UnaryOps, ATenOp): - pass + """ + OUT = log2(X) + """ @dataclass(frozen=True) class Sqrt(UnaryOps, ATenOp): - pass + """ + OUT = sqrt(X) + """ @dataclass(frozen=True) class Bitcast(UnaryOps, ATenOp): @@ -179,25 +188,24 @@ class Not(UnaryOps, ATenOp): Logical not if the X is a boolean otherwise lognot ~x """ - pass ### BinaryOps @dataclass(frozen=True) class Add(BinaryOps, ATenOp): """ OUT = Add(X, Y) """ - pass @dataclass(frozen=True) class Mul(BinaryOps, ATenOp): """ OUT = Mul(X, Y) """ - pass @dataclass(frozen=True) class IDiv(BinaryOps, ATenOp): - pass + """ + OUT = A // B + """ @dataclass(frozen=True) class And(BinaryOps, ATenOp): From fc442a6bdfc8ec32190e527dce9e07f41e7fab6a Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:06:38 +0900 Subject: [PATCH 37/53] smax --- caten/tensor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/caten/tensor.py b/caten/tensor.py index fc68246f..cc54c201 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -102,11 +102,7 @@ def permute(self, order, *args) -> Self: def expand(self, shape, *args) -> Self: new_shape = tuple(from_ if ir.ATenOp.eql(to, -1) or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))))) return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) - -def smax(a, b): - # [TODO] Const fold and == 1 is required - print("TODO: SMAX") - return max(a, b) + ## arithmetic mixin class ATenArith(): def _broadcasted(self, y:Tensor|int|float, reverse:bool=False) -> tuple[Tensor, Tensor]: @@ -118,6 +114,12 @@ def _broadcasted(self, y:Tensor|int|float, reverse:bool=False) -> tuple[Tensor, if reverse: x, y = y, x # compute the output shape def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: + def smax(a, b): + if ir.ATenOp.eql(a, 1): return b + elif ir.ATenOp.eql(b, 1): return a + else: + assert ir.ATenOp.eql(a, b) + return a # a != b is asserted here? return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) out_shape = _broadcast_shape(x.shape, y.shape) return x._broadcast_to(out_shape), y._broadcast_to(out_shape) From 3061ab2affc8bfad37676e4ef4a9d2237cb6b796 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:09:43 +0900 Subject: [PATCH 38/53] write specs on def --- caten/ir.py | 4 ++++ caten/simplifier.py | 13 +++++++------ caten/tensor.py | 3 --- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 161f4484..5d22d843 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -4,6 +4,7 @@ from typing import List, Dict, Any, Union import itertools, weakref, dataclasses from dataclasses import dataclass +import operator, math from .dtype import DType, index class ATenOpMetaclass(type): @@ -159,6 +160,7 @@ class Sin(UnaryOps, ATenOp): """ OUT = sin(X) """ + python_op = math.sin @dataclass(frozen=True) class Exp2(UnaryOps, ATenOp): @@ -194,12 +196,14 @@ class Add(BinaryOps, ATenOp): """ OUT = Add(X, Y) """ + python_op = operator.add @dataclass(frozen=True) class Mul(BinaryOps, ATenOp): """ OUT = Mul(X, Y) """ + python_op = operator.mul @dataclass(frozen=True) class IDiv(BinaryOps, ATenOp): diff --git a/caten/simplifier.py b/caten/simplifier.py index d173d31e..665f1ccc 100644 --- a/caten/simplifier.py +++ b/caten/simplifier.py @@ -94,18 +94,19 @@ def _is_num(x: Any) -> bool: constant_folder = Simplifier( # UnaryOps [( - Pat(ops[0], src=(Pat(ir.Const, name="x"))), - lambda x: ir.Const.new(ops[1](x.value), x.T.dtype) + Pat(ops, src=(Pat(ir.Const, name="x"))), + lambda x: ir.Const.new(ops.python_op(x.value), x.T.dtype) if _is_num(x.value) else None,) - for ops in [(ir.Sin, math.sin)]] + + for ops in [ir.Sin] + ] + # BinaryOps [( - Pat(ops[0], src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), - lambda a, b: ir.Const.new(ops[1](a.value, b.value), a.T.dtype) + Pat(ops, src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), + lambda a, b: ir.Const.new(ops.python_op(a.value, b.value), a.T.dtype) if (a.T.dtype == b.T.dtype and _is_num(a.value) and _is_num(b.value)) else None,) - for ops in [(ir.Add, operator.add), (ir.Mul, operator.mul)] + for ops in [ir.Add, ir.Mul] ] # Ternary Ops? ) diff --git a/caten/tensor.py b/caten/tensor.py index cc54c201..8497f5dc 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -123,7 +123,6 @@ def smax(a, b): return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) out_shape = _broadcast_shape(x.shape, y.shape) return x._broadcast_to(out_shape), y._broadcast_to(out_shape) - # TODO: # - reduce option # - ir.Add.new (or binop) can have reduce option @@ -141,8 +140,6 @@ def __add__(self, other: Any): return self.add(other) def __radd__(self, other: Any): return self.add(other, reverse=True) def __mul__(self, other: Any): return self.mul(other) def __rmul__(self, other: Any): return self.mul(other, reverse=True) - - # TODO: self == 1 is evalued to true if self is const ## math mixin class ATenMath(): @ATen.top From d6e83da1c734a729015d11706b11263ad15e8ecb Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:12:58 +0900 Subject: [PATCH 39/53] write specs on def --- caten/ir.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index 5d22d843..deeaf2db 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -148,12 +148,14 @@ class Neg(UnaryOps, ATenOp): """ OUT = -X """ + python_op = lambda x: -x @dataclass(frozen=True) class Recip(UnaryOps, ATenOp): """ OUT = 1/X """ + python_op = lambda x: 1/x @dataclass(frozen=True) class Sin(UnaryOps, ATenOp): @@ -167,18 +169,21 @@ class Exp2(UnaryOps, ATenOp): """ OUT = exp2(X) """ + python_op = math.exp2 @dataclass(frozen=True) class Log2(UnaryOps, ATenOp): """ OUT = log2(X) """ + python_op = math.log2 @dataclass(frozen=True) class Sqrt(UnaryOps, ATenOp): """ OUT = sqrt(X) """ + python_op = math.sqrt @dataclass(frozen=True) class Bitcast(UnaryOps, ATenOp): @@ -210,6 +215,7 @@ class IDiv(BinaryOps, ATenOp): """ OUT = A // B """ + python_op = operator.floordiv @dataclass(frozen=True) class And(BinaryOps, ATenOp): @@ -229,30 +235,30 @@ class Xor(BinaryOps, ATenOp): @dataclass(frozen=True) class Max(BinaryOps, ATenOp): - pass + python_op = max @dataclass(frozen=True) class Mod(BinaryOps, ATenOp): - pass + python_op = operator.mod @dataclass(frozen=True) class Neq(BinaryOps, ATenOp): - pass + python_op = operator.ne @dataclass(frozen=True) class Lt(BinaryOps, ATenOp): - pass + python_op = operator.lt ### TernaryOps @dataclass(frozen=True) class Where(TernaryOps, ATenOp): - pass + python_op = lambda a, b, c: b if a else c ### Allocation @dataclass(frozen=True) class Const(ViewOps, ATenOp): - value: Union[int, float, str] = 0.0 + value: Union[int, float, str, bool] = 0.0 @staticmethod - def new(value: Union[int, float, str], dtype: DType): + def new(value: Union[int, float, str, bool], dtype: DType): return Const(args=(), value=value, T=ATenOpType(axes=(), dtype=dtype)) @dataclass(frozen=True) From 10b0a412dfa04b4b2a54a0caa2addf2310991bf4 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:28:21 +0900 Subject: [PATCH 40/53] Fix scope issue --- caten/simplifier.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/caten/simplifier.py b/caten/simplifier.py index 665f1ccc..db135c63 100644 --- a/caten/simplifier.py +++ b/caten/simplifier.py @@ -94,21 +94,18 @@ def _is_num(x: Any) -> bool: constant_folder = Simplifier( # UnaryOps [( - Pat(ops, src=(Pat(ir.Const, name="x"))), - lambda x: ir.Const.new(ops.python_op(x.value), x.T.dtype) - if _is_num(x.value) - else None,) - for ops in [ir.Sin] - ] + + Pat(op, src=(Pat(ir.Const, name="x"),)), + (lambda op: (lambda x: ir.Const.new(op.python_op(x.value), x.T.dtype) + if _is_num(x.value) else None))(op), + ) for op in [ir.Neg, ir.Recip, ir.Sin, ir.Exp2, ir.Log2, ir.Sqrt]] + + # BinaryOps [( - Pat(ops, src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), - lambda a, b: ir.Const.new(ops.python_op(a.value, b.value), a.T.dtype) - if (a.T.dtype == b.T.dtype and _is_num(a.value) and _is_num(b.value)) - else None,) - for ops in [ir.Add, ir.Mul] - ] - # Ternary Ops? + Pat(op, src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), + (lambda op: (lambda a, b: ir.Const.new(op.python_op(a.value, b.value), a.T.dtype) + if (a.T.dtype == b.T.dtype and _is_num(a.value) and _is_num(b.value)) else None))(op), + ) for op in [ir.Add, ir.Mul, ir.IDiv, ir.Max, ir.Mod, ir.Neq, ir.Lt]] ) + simplifier = constant_folder From b0a7749f3af44b73c797f8629791a30a9b2a232b Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:36:47 +0900 Subject: [PATCH 41/53] Fix scope issue --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a9bb6610..339fd684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["*.ipynb"] [tool.ruff.lint] select = ["E", "F", "I", "B", "Q"] -ignore = ["E501"] +ignore = ["E501", "E701", "E401", "E731"] [tool.mypy] python_version = "3.11" From c34c90acc9cddfeba27e735aaaeb1650560c7fde Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:41:34 +0900 Subject: [PATCH 42/53] fix ruff --- caten/__init__.py | 3 +-- caten/dtype.py | 4 +++- caten/helpers.py | 4 +++- caten/ir.py | 18 +++++++++--------- caten/simplifier.py | 8 ++++---- caten/tensor.py | 22 +++++++++++++--------- test/test_kernel.py | 33 +++++++++++++++++---------------- test/test_movements.py | 1 - 8 files changed, 50 insertions(+), 43 deletions(-) diff --git a/caten/__init__.py b/caten/__init__.py index ae1d2c88..0d71796a 100644 --- a/caten/__init__.py +++ b/caten/__init__.py @@ -1,8 +1,7 @@ -from . import dtype, helpers, ir, tensor from .dtype import * -from .tensor import * from .runtime import cpu from .simplifier import * +from .tensor import * __all__ = [ "dtype", diff --git a/caten/dtype.py b/caten/dtype.py index a9f7963d..d8c1a17f 100644 --- a/caten/dtype.py +++ b/caten/dtype.py @@ -1,5 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass, fields + +from dataclasses import dataclass + class DTypeMetaClass(type): dcache: dict[tuple, DType] = {} diff --git a/caten/helpers.py b/caten/helpers.py index 1ad5a204..48f3529d 100644 --- a/caten/helpers.py +++ b/caten/helpers.py @@ -1,6 +1,8 @@ from __future__ import annotations + +import functools +import operator from typing import Iterable, TypeVar -import functools, operator T = TypeVar("T") diff --git a/caten/ir.py b/caten/ir.py index deeaf2db..ba64c837 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -1,12 +1,16 @@ from __future__ import annotations -from abc import ABCMeta, abstractmethod -from typing import List, Dict, Any, Union -import itertools, weakref, dataclasses +import dataclasses +import itertools +import math +import operator +import weakref from dataclasses import dataclass -import operator, math +from typing import Any, Dict, List, Union + from .dtype import DType, index + class ATenOpMetaclass(type): cache: Dict[tuple, weakref.ReferenceType[ATenOp]] = {} @staticmethod @@ -87,7 +91,7 @@ def viz(self): @property def item(self): # Returns scalar value if self is constant folded - if isinstance(self, Const) and isinstance(getattr(self, "value"), (int, float)): + if isinstance(self, Const) and isinstance(self.value, (int, float)): return self.value else: return self # Mixin for computing shapes (required by reshape, etc) @@ -225,10 +229,6 @@ class And(BinaryOps, ATenOp): class Or(BinaryOps, ATenOp): pass -@dataclass(frozen=True) -class And(BinaryOps, ATenOp): - pass - @dataclass(frozen=True) class Xor(BinaryOps, ATenOp): pass diff --git a/caten/simplifier.py b/caten/simplifier.py index db135c63..97b28fe5 100644 --- a/caten/simplifier.py +++ b/caten/simplifier.py @@ -1,13 +1,13 @@ from __future__ import annotations -import inspect, operator -import math - +import inspect from dataclasses import is_dataclass, replace from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from .ir import ATenOp import caten.ir as ir + +from .ir import ATenOp + OpType = Type[ATenOp] class Pat: diff --git a/caten/tensor.py b/caten/tensor.py index 8497f5dc..7f7a08ae 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -1,10 +1,14 @@ from __future__ import annotations -from abc import ABCMeta, abstractmethod -from typing import Any, Optional, Tuple, Union, ClassVar, Self + import os +from abc import ABCMeta, abstractmethod +from typing import Any, Callable, ClassVar, List, Self, Tuple, Union + import caten.ir as ir -from .dtype import default_float, index, floats, integers -from caten.helpers import argfix, prod, align_left +from caten.helpers import align_left, argfix, prod + +from .dtype import DType, default_float, floats, index, integers + ## Backend Abstraction DEVICE_TO_TENSOR = {} def get_backend(): return os.environ.get("BACKEND", "CPU") @@ -29,7 +33,7 @@ def const(cls, obj: Any, dtype: DType=index): case _: raise TypeError(f"ATen.const: Only integer or float objects can become constant! getting {obj}") return ir.Const.new(obj, dtype) def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return Tensor(op=op(*args, **kwargs)) - def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec: return TensorSpec(item) + def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> ATenSpec: return ATenSpec(item) def __repr__(self) -> str: # TODO: Display Shape, realized buffer, etc. return f"{self.__class__.__name__}<{self.op}>" @@ -90,7 +94,7 @@ def reshape(self, shape, *args) -> Self: ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) return self if ir.ATenOp.equals(ret.shape, self.shape) else ret @ATen.top - def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self: + def shrink(self, arg: tuple[tuple[int, int] | None, ...]) -> Self: raise NotImplementedError("shrink todo") @ATen.top def permute(self, order, *args) -> Self: @@ -113,14 +117,14 @@ def _broadcasted(self, y:Tensor|int|float, reverse:bool=False) -> tuple[Tensor, raise TypeError("Cannot add x and y (dtypes mismatch, todo)") if reverse: x, y = y, x # compute the output shape - def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: + def _broadcast_shape(*shapes:tuple[int, ...]) -> tuple[int, ...]: def smax(a, b): if ir.ATenOp.eql(a, 1): return b elif ir.ATenOp.eql(b, 1): return a else: assert ir.ATenOp.eql(a, b) return a # a != b is asserted here? - return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes))) + return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*align_left(*shapes))) out_shape = _broadcast_shape(x.shape, y.shape) return x._broadcast_to(out_shape), y._broadcast_to(out_shape) # TODO: @@ -181,7 +185,7 @@ def compile(self): @staticmethod @abstractmethod - def render(op: ATenOp): + def render(op): pass class Tensor(ATenBase): diff --git a/test/test_kernel.py b/test/test_kernel.py index a6e8fada..9c0b9ca4 100644 --- a/test/test_kernel.py +++ b/test/test_kernel.py @@ -1,27 +1,28 @@ import caten as C + def test_tensor(): tensor = C.Tensor.from_shape([10, 10], dtype=C.float32) print(tensor) print(tensor.op.T) print(tensor.reshape([2, 5, 10])) -def atest_matmul_kernel(): - @C.kernel() - def matmul(A: C.Tensor[N, K], B: C.Tensor[K, M]): - Out = C.Tensor(N, M, dtype=A.dtype) - with C.range(N) as i: - with C.range(M) as j: - acc = C.LocalVar(0.0) - with C.range(K) as k: - acc += + A[i, k] * B[k, j] - Out[i, j] = C.tanh(acc) - return Out +#def atest_matmul_kernel(): +# @C.kernel() +# def matmul(A: C.Tensor[N, K], B: C.Tensor[K, M]): +# Out = C.Tensor(N, M, dtype=A.dtype) +# with C.range(N) as i: +# with C.range(M) as j: +# acc = C.LocalVar(0.0) +# with C.range(K) as k: +# acc += + A[i, k] * B[k, j] +# Out[i, j] = C.tanh(acc) +# return Out # TODO: # 1. VMAP # 2. Symbolic - N = C.param("N") - tmp = C.randn(N, 10, 10) - a, b, c = C.randn(10, 10), C.randn(10, 10), C.randn(10, 10) - c = matmul(a, b, c) - tmp * c +# N = C.param("N") +# tmp = C.randn(N, 10, 10) +# a, b, c = C.randn(10, 10), C.randn(10, 10), C.randn(10, 10) +# c = matmul(a, b, c) +# tmp * c diff --git a/test/test_movements.py b/test/test_movements.py index 36ab322c..1fc350d3 100644 --- a/test/test_movements.py +++ b/test/test_movements.py @@ -1,4 +1,3 @@ -import pytest def test_reshape(): pass From 51abfad519800f137d46247ce488e6c89b9fb8c6 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:47:56 +0900 Subject: [PATCH 43/53] fix ruff --- caten/__init__.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/caten/__init__.py b/caten/__init__.py index 0d71796a..c61529e3 100644 --- a/caten/__init__.py +++ b/caten/__init__.py @@ -1,12 +1,4 @@ -from .dtype import * -from .runtime import cpu -from .simplifier import * -from .tensor import * - -__all__ = [ - "dtype", - "helpers", - "ir", - "tensor" - "simplifier" -] +from .dtype import * # noqa: F403, I001 +from .simplifier import * # noqa: F403, I001 +from .tensor import * # noqa: F403, I001 +from .runtime import cpu # noqa: I001, F401 From 62157359f24aae4d68bc6a97b6d0a177bc632798 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 21:51:21 +0900 Subject: [PATCH 44/53] fix ruff --- caten/ir.py | 10 +++++----- caten/runtime/cpu.py | 1 + caten/tensor.py | 10 +++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index ba64c837..dfb65401 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -52,7 +52,7 @@ class ATenOpType(): is_ptr: bool = False # TODO: for vectorize? def index(self, indices: List[ATenOp]): assert self.ndim == len(indices) - total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) + total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes, strict=True)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) if self.offset: total = Add([total, self.offset]) return total @property @@ -62,7 +62,7 @@ def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: def _mul(a, b): return Mul((_const(a), _const(b))) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( - axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides)]), + axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides, strict=True)]), dtype=dtype, ) @@ -72,7 +72,7 @@ class ATenOp(metaclass=ATenOpMetaclass): T: Union[ATenOpType, None] = None # this should be provided via T=... option, or inferred via verify method. @property def predecessors(self) -> tuple[ATenOp, ...]: - return tuple(args) + tuple(*[tuple(axis.size, axis.stride, axis.offset, axis.incf) for axis in self.T.axes]) + () if self.offset is None else tuple([self.offset]) + return tuple(self.args) + tuple(*[tuple(axis.size, axis.stride, axis.offset, axis.incf) for axis in self.T.axes]) + () if self.offset is None else tuple([self.offset]) @classmethod def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: @@ -118,7 +118,7 @@ def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp] Compare two lists element-wise using `ATenOp.eql` """ if not len(a) == len(b): return False - for ai, bi in zip(a, b): + for ai, bi in zip(a, b, strict=True): if not ATenOp.eql(ai, bi): return False return True ## == Tensor Graph ============================================================ @@ -297,7 +297,7 @@ def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: assert ATenOp.eql(old_axis, 1), f"The axis to expand should be evaluated to 1, getting {old_axis}" return ATenAxis(size=_const(new_size), stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) return View((tensor,), T=ATenOpType( - axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape)]), + axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape, strict=True)]), dtype=tensor.T.dtype, offset=tensor.T.offset, is_ptr=tensor.T.is_ptr diff --git a/caten/runtime/cpu.py b/caten/runtime/cpu.py index 76309c94..84fa9179 100644 --- a/caten/runtime/cpu.py +++ b/caten/runtime/cpu.py @@ -1,5 +1,6 @@ import caten as C + class CPUTensor(C.ATenBase): def allocate(self): pass diff --git a/caten/tensor.py b/caten/tensor.py index 7f7a08ae..4fa1f676 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -22,9 +22,9 @@ def __init__(self, shape: Tuple[Any, ...]): def __repr__(self) -> str: return f"ATenSpec{self.shape}" ## Tensor compiler core class ATen: - op: ATenOp # ATen is just a wrapper for ATenOp + op: ir.ATenOp # ATen is just a wrapper for ATenOp @classmethod - def from_shape(cls, shape: List[ATenOp], dtype: DType=default_float): return Tensor(op=ir.Allocate.new(shape, dtype)) + def from_shape(cls, shape: List[ir.ATenOp], dtype: DType=default_float): return Tensor(op=ir.Allocate.new(shape, dtype)) @classmethod def const(cls, obj: Any, dtype: DType=index): match obj: @@ -78,7 +78,7 @@ def _broadcast_to(self, new_shape: List[ATen]) -> Self: if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape}") shape, _ = align_left(self.shape, new_shape) - if not all(ir.ATenOp.eql(s, ns) or ir.ATenOp.eql(s, 1) for s, ns in zip(shape, new_shape)): + if not all(ir.ATenOp.eql(s, ns) or ir.ATenOp.eql(s, 1) for s, ns in zip(shape, new_shape, strict=True)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") reshaped = self.reshape(shape) ret = Tensor(op=ir.View.expand(self.op, new_shape)) @@ -104,7 +104,7 @@ def permute(self, order, *args) -> Self: return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self @ATen.top def expand(self, shape, *args) -> Self: - new_shape = tuple(from_ if ir.ATenOp.eql(to, -1) or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))))) + new_shape = tuple(from_ if ir.ATenOp.eql(to, -1) or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))), strict=True)) return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) ## arithmetic mixin @@ -124,7 +124,7 @@ def smax(a, b): else: assert ir.ATenOp.eql(a, b) return a # a != b is asserted here? - return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*align_left(*shapes))) + return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*align_left(*shapes), strict=True)) out_shape = _broadcast_shape(x.shape, y.shape) return x._broadcast_to(out_shape), y._broadcast_to(out_shape) # TODO: From e72883f986478d9a338b614a4033cafe46632d8e Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:18:40 +0900 Subject: [PATCH 45/53] helpers --- caten/helpers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/caten/helpers.py b/caten/helpers.py index 48f3529d..6b470019 100644 --- a/caten/helpers.py +++ b/caten/helpers.py @@ -2,19 +2,19 @@ import functools import operator -from typing import Iterable, TypeVar +from typing import Iterable, TypeVar, Any T = TypeVar("T") def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1) -def argfix(*x): +def argfix(*x: Any) -> tuple[Any, ...]: if x and x[0].__class__ in (tuple, list): if len(x) != 1: raise ValueError(f"bad arg {x}") return tuple(x[0]) return x -def align_left(*shapes): +def align_left(*shapes: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: # unsqueeze left to make every shape same length max_dim = max(len(shape) for shape in shapes) return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes) From 50372b0b72fe2759b966e336a98cfa2084e31af7 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:18:50 +0900 Subject: [PATCH 46/53] helpers --- caten/dtype.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/caten/dtype.py b/caten/dtype.py index d8c1a17f..b768dca2 100644 --- a/caten/dtype.py +++ b/caten/dtype.py @@ -3,9 +3,11 @@ from dataclasses import dataclass +from typing import Any + class DTypeMetaClass(type): dcache: dict[tuple, DType] = {} - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> DType: if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret DTypeMetaClass.dcache[args] = ret = super().__call__(*args) return ret @@ -15,7 +17,7 @@ def __call__(cls, *args, **kwargs): class DType: name: str @staticmethod - def new(name:str): return DType(name) + def new(name:str) -> DType: return DType(name) ## definitions float64 = DType.new("float64") From 77c91572e825cfbfd02d66eb4ef15dc2b05f7945 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:20:00 +0900 Subject: [PATCH 47/53] helpers --- test/test_kernel.py | 2 +- test/test_movements.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_kernel.py b/test/test_kernel.py index 9c0b9ca4..70e76972 100644 --- a/test/test_kernel.py +++ b/test/test_kernel.py @@ -1,7 +1,7 @@ import caten as C -def test_tensor(): +def test_tensor() -> None: tensor = C.Tensor.from_shape([10, 10], dtype=C.float32) print(tensor) print(tensor.op.T) diff --git a/test/test_movements.py b/test/test_movements.py index 1fc350d3..db8e22be 100644 --- a/test/test_movements.py +++ b/test/test_movements.py @@ -1,9 +1,9 @@ -def test_reshape(): +def test_reshape() -> None: pass -def test_reshape_const(): +def test_reshape_const() -> None: pass -def test_reshape_dynamic(): +def test_reshape_dynamic() -> None: pass From 4700274de2b319b8cfa7c8907a12cc82a4b68743 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:20:11 +0900 Subject: [PATCH 48/53] helpers --- caten/ir.py | 91 ++++++++++++++++++++++++-------------------- caten/runtime/cpu.py | 11 +++--- 2 files changed, 55 insertions(+), 47 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index dfb65401..8c83eb85 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -6,7 +6,7 @@ import operator import weakref from dataclasses import dataclass -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Sequence from .dtype import DType, index @@ -23,11 +23,11 @@ def _freeze(x: Any) -> Any: if isinstance(x, dict): return tuple(sorted((k, ATenOpMetaclass._freeze(v)) for k, v in x.items())) return x - def __call__(cls, args: tuple[ATenOp, ...], T: "ATenOpType | None" = None, **kwargs): - T = cls.verify(args, T, **kwargs) # run type inference+verification + def __call__(cls, args: tuple[ATenOp, ...] | list[ATenOp], T: "ATenOpType | None" = None, **kwargs: Any) -> ATenOp: + T = cls.verify(tuple(args), T, **kwargs) # run type inference+verification wret = ATenOpMetaclass.cache.get(key:=(cls, tuple(args), ATenOpMetaclass._freeze(T), ATenOpMetaclass._freeze(kwargs)), None) if wret is not None and (ret:=wret()) is not None: return ret.simplify() - ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(args, T=T, **kwargs)) + ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(tuple(args), T=T, **kwargs)) return created.simplify() @dataclass(frozen=True) @@ -36,30 +36,30 @@ class ATenAxis(): stride: ATenOp offset: ATenOp incf: ATenOp - def index(self, i: ATenOp): - assert i.T.dtype == index, "ATenAxis.index: range index should be type of index." - return Mul(self.stride, Add(Mul(i, self.incf), self.offset)) + def index(self, i: ATenOp) -> ATenOp: + assert i.T is not None and i.T.dtype == index, "ATenAxis.index: range index should be type of index." + return Mul((self.stride, Add((Mul((i, self.incf)), self.offset)))) -def _const(val: int, dtype: DType=index): +def _const(val: int, dtype: DType=index) -> ATenOp: if isinstance(val, Const): return val else: return Const.new(val, dtype) @dataclass(frozen=True) class ATenOpType(): - axes: tuple[ATenAxis] + axes: tuple[ATenAxis, ...] dtype: DType offset: Union[ATenOp, None] = None is_ptr: bool = False # TODO: for vectorize? - def index(self, indices: List[ATenOp]): + def index(self, indices: List[ATenOp]) -> Any: assert self.ndim == len(indices) total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes, strict=True)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) - if self.offset: total = Add([total, self.offset]) + if self.offset: total = Add((total, self.offset)) # type: ignore return total @property - def ndim(self): return len(self.axes) + def ndim(self) -> int: return len(self.axes) @staticmethod def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: - def _mul(a, b): return Mul((_const(a), _const(b))) + def _mul(a: Any, b: Any) -> Any: return Mul((_const(a), _const(b))) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides, strict=True)]), @@ -68,52 +68,53 @@ def _mul(a, b): return Mul((_const(a), _const(b))) @dataclass(frozen=True) class ATenOp(metaclass=ATenOpMetaclass): - args: List[ATenOp] + args: tuple[ATenOp, ...] T: Union[ATenOpType, None] = None # this should be provided via T=... option, or inferred via verify method. @property def predecessors(self) -> tuple[ATenOp, ...]: - return tuple(self.args) + tuple(*[tuple(axis.size, axis.stride, axis.offset, axis.incf) for axis in self.T.axes]) + () if self.offset is None else tuple([self.offset]) + return tuple(self.args) + (tuple(*[tuple((axis.size, axis.stride, axis.offset, axis.incf)) for axis in self.T.axes]) + () if self.T is not None else ()) + ((self.offset,) if self.T and self.T.offset is not None else ()) # type: ignore @classmethod - def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: raise NotImplementedError("Not implemented") - def simplify(self): + def simplify(self) -> ATenOp: from caten.simplifier import simplifier return simplifier.simplify(self) - def deepwalk(self): + def deepwalk(self) -> None: pass - def viz(self): + def viz(self) -> None: pass @property - def item(self): + def item(self) -> Union[int, float, ATenOp]: # Returns scalar value if self is constant folded if isinstance(self, Const) and isinstance(self.value, (int, float)): return self.value else: return self # Mixin for computing shapes (required by reshape, etc) # TODO: Use same semantic of broadcast as tensor - def __add__(self, other: Any): return Add((self, _const(other))) - def __radd__(self, other: Any): return Add((_const(other), self)) - def __mul__(self, other: Any): return Mul((self, _const(other))) - def __rmul__(self, other: Any): return Mul((_const(other), self)) + def __add__(self, other: Any) -> ATenOp: return Add((self, _const(other))) + def __radd__(self, other: Any) -> ATenOp: return Add((_const(other), self)) + def __mul__(self, other: Any) -> ATenOp: return Mul((self, _const(other))) + def __rmul__(self, other: Any) -> ATenOp: return Mul((_const(other), self)) # note: do not try to overload __eq__ since it is need to compute hash @staticmethod - def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]): + def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]) -> bool: """ Compare two scalars (Python numbers or ATenOp scalars) for equality. """ if isinstance(a, (int, float)) and isinstance(b, (int, float)): return (a == b) + assert isinstance(a, ATenOp) and a.T is not None dtype = a.T.dtype if isinstance(a, ATenOp) else b.T.dtype # A or B is asserted to have a dtype - a, b = _const(a, dtype=dtype), _const(b, dtype=dtype) + a, b = _const(a, dtype=dtype), _const(b, dtype=dtype) # type: ignore # Note(hikettei): this comparison highly depends on whether they are constant folded. # plus, cannot verify the equivalence of A*B and B*A return a == b @staticmethod - def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp]]): + def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp]]) -> bool: """ Compare two lists element-wise using `ATenOp.eql` """ @@ -125,25 +126,28 @@ def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp] class UnaryOps(): # ops whose first argument is returned dtype @classmethod - def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: assert len(args) == 1, f"UnaryOp {cls.__name__} takes one argument, getting {args}" + assert args[0].T is not None return args[0].T class BinaryOps(): # ops whose first argument is returned dtype @classmethod - def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: assert len(args) == 2, f"BinaryOp {cls.__name__} takes two argument, getting {args}" + assert args[0].T is not None return args[0].T class TernaryOps(): # ops whose first argument is returned dtype @classmethod - def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: assert len(args) == 3, f"TernaryOp {cls.__name__} takes three argument, getting {args}" + assert args[0].T is not None return args[0].T class ViewOps(): # ops whose return dtypes are explicitly provided via T option @classmethod - def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs) -> ATenOpType: + def verify(cls, args: tuple[ATenOp, ...], T: Union[None, ATenOpType], **kwargs: Any) -> ATenOpType: assert T is not None, f"Cannot create {cls.__name__} without providing T" return T ### UnaryOps @@ -258,7 +262,7 @@ class Where(TernaryOps, ATenOp): class Const(ViewOps, ATenOp): value: Union[int, float, str, bool] = 0.0 @staticmethod - def new(value: Union[int, float, str, bool], dtype: DType): + def new(value: Union[int, float, str, bool], dtype: DType) -> Const: return Const(args=(), value=value, T=ATenOpType(axes=(), dtype=dtype)) @dataclass(frozen=True) @@ -267,7 +271,7 @@ class Allocate(ViewOps, ATenOp): Allocate(S1, S2, S3, ...) """ @staticmethod - def new(shape: List[Any], dtype: DType): + def new(shape: List[Any], dtype: DType) -> Allocate: return Allocate((), T=ATenOpType.from_shape(shape, dtype)) @dataclass(frozen=True) @@ -277,11 +281,13 @@ class View(ViewOps, ATenOp): """ # This is the definition of view @staticmethod - def reshape(tensor: ATenOp, shape: List[ATenOp]): + def reshape(tensor: ATenOp, shape: List[ATenOp]) -> View: + assert tensor.T is not None return View((tensor,), T=ATenOpType.from_shape(shape, tensor.T.dtype)) @staticmethod - def permute(tensor: ATenOp, order: List[int]): + def permute(tensor: ATenOp, order: List[int]) -> View: + assert tensor.T is not None return View((tensor,), T=ATenOpType( axes=tuple([tensor.T.axes[i] for i in order]), dtype=tensor.T.dtype, @@ -290,12 +296,13 @@ def permute(tensor: ATenOp, order: List[int]): )) @staticmethod - def expand(tensor: ATenOp, shape: List[Union[int, ATenOp]]): - def _expand(old_axis: ATenAxis, new_size: ATenOp) -> ATenAxis: + def expand(tensor: ATenOp, shape: List[Union[int, ATenOp]]) -> View: + assert tensor.T is not None + def _expand(old_axis: ATenAxis, new_size: int | float | ATenOp) -> ATenAxis: if ATenOp.eql(old_axis.size, new_size): return old_axis else: - assert ATenOp.eql(old_axis, 1), f"The axis to expand should be evaluated to 1, getting {old_axis}" - return ATenAxis(size=_const(new_size), stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) + assert ATenOp.eql(old_axis.size, 1), f"The axis to expand should be evaluated to 1, getting {old_axis}" # Fix: old_axis -> old_axis.size + return ATenAxis(size=_const(new_size), stride=Const.new(0, index), offset=Const.new(0, index), incf=Const.new(1, index)) # type: ignore return View((tensor,), T=ATenOpType( axes=tuple([_expand(old_axis, new_size) for (old_axis, new_size) in zip(tensor.T.axes, shape, strict=True)]), dtype=tensor.T.dtype, @@ -308,9 +315,9 @@ class Reduce(ATenOp): """ OUT = Reduce(A, B, op=BinaryOps) """ - op: BinaryOps = Add + op: type[BinaryOps] = Add @classmethod - def from_ast_expr(cls): + def from_ast_expr(cls) -> None: pass @dataclass(frozen=True) @@ -337,7 +344,7 @@ class Progn(ATenOp): class Polyhedral(ATenOp): pass -def Var(): +def Var() -> None: pass # e.g.: diff --git a/caten/runtime/cpu.py b/caten/runtime/cpu.py index 84fa9179..c6130328 100644 --- a/caten/runtime/cpu.py +++ b/caten/runtime/cpu.py @@ -1,20 +1,21 @@ +from typing import Any import caten as C class CPUTensor(C.ATenBase): - def allocate(self): + def allocate(self) -> None: pass - def free(self): + def free(self) -> None: pass #@staticmethod - def compile(self): + def compile(self) -> None: pass @staticmethod - def render(op): - def _render(node): + def render(op: Any) -> None: + def _render(node: Any) -> None: pass C.ATenBase.register("CPU", CPUTensor) From a2ce35965e52ef9e52af3abb9961efa081d99ab3 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:23:06 +0900 Subject: [PATCH 49/53] helpers --- caten/simplifier.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/caten/simplifier.py b/caten/simplifier.py index 97b28fe5..9b0151a0 100644 --- a/caten/simplifier.py +++ b/caten/simplifier.py @@ -87,22 +87,22 @@ def simplify(self, root: ATenOp, ctx_obj: Any = None, max_iters: int = 30000) -> return cur # Guard Methods -def Guard(obj): pass +#def Guard(obj): pass def _is_num(x: Any) -> bool: return isinstance(x, (int, float)) and not isinstance(x, bool) constant_folder = Simplifier( # UnaryOps [( - Pat(op, src=(Pat(ir.Const, name="x"),)), - (lambda op: (lambda x: ir.Const.new(op.python_op(x.value), x.T.dtype) + Pat(op, src=(Pat(ir.Const, name="x"),)), # type: ignore + (lambda op: (lambda x: ir.Const.new(op.python_op(x.value), x.T.dtype) # type: ignore if _is_num(x.value) else None))(op), ) for op in [ir.Neg, ir.Recip, ir.Sin, ir.Exp2, ir.Log2, ir.Sqrt]] + # BinaryOps [( - Pat(op, src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), - (lambda op: (lambda a, b: ir.Const.new(op.python_op(a.value, b.value), a.T.dtype) + Pat(op, src=(Pat(ir.Const, name="a"), Pat(ir.Const, name="b"))), # type: ignore + (lambda op: (lambda a, b: ir.Const.new(op.python_op(a.value, b.value), a.T.dtype) # type: ignore if (a.T.dtype == b.T.dtype and _is_num(a.value) and _is_num(b.value)) else None))(op), ) for op in [ir.Add, ir.Mul, ir.IDiv, ir.Max, ir.Mod, ir.Neq, ir.Lt]] ) From dec32aedbad22ec128089aa23d6fba8ba2438a8c Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:36:26 +0900 Subject: [PATCH 50/53] helpers --- caten/helpers.py | 2 +- caten/ir.py | 2 +- caten/tensor.py | 31 ++++++++++++++++++------------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/caten/helpers.py b/caten/helpers.py index 6b470019..cfb41c35 100644 --- a/caten/helpers.py +++ b/caten/helpers.py @@ -14,7 +14,7 @@ def argfix(*x: Any) -> tuple[Any, ...]: return tuple(x[0]) return x -def align_left(*shapes: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: +def align_left(*shapes: tuple[Any, ...]) -> tuple[tuple[Any, ...], ...]: # unsqueeze left to make every shape same length max_dim = max(len(shape) for shape in shapes) return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes) diff --git a/caten/ir.py b/caten/ir.py index 8c83eb85..ca7b4195 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -114,7 +114,7 @@ def eql(a: Union[int, float, ATenOp], b: Union[int, float, ATenOp]) -> bool: # plus, cannot verify the equivalence of A*B and B*A return a == b @staticmethod - def equals(a: List[Union[int, float, ATenOp]], b: List[Union[int, float, ATenOp]]) -> bool: + def equals(a: tuple[Union[int, float, ATenOp], ...], b: tuple[Union[int, float, ATenOp], ...]) -> bool: """ Compare two lists element-wise using `ATenOp.eql` """ diff --git a/caten/tensor.py b/caten/tensor.py index 4fa1f676..29182435 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -11,46 +11,51 @@ ## Backend Abstraction DEVICE_TO_TENSOR = {} -def get_backend(): return os.environ.get("BACKEND", "CPU") +def get_backend() -> str: return os.environ.get("BACKEND", "CPU") ## Tensor annotation for jit/aot shape check class ATenSpec: """ C.Tensor[M, N] -> ATenSpec(M N) """ def __init__(self, shape: Tuple[Any, ...]): - self.shape: List[Union[int, str]] = shape + self.shape: tuple[Union[int, str], ...] = shape def __repr__(self) -> str: return f"ATenSpec{self.shape}" ## Tensor compiler core class ATen: op: ir.ATenOp # ATen is just a wrapper for ATenOp @classmethod - def from_shape(cls, shape: List[ir.ATenOp], dtype: DType=default_float): return Tensor(op=ir.Allocate.new(shape, dtype)) + def from_shape(cls, shape: List[ir.ATenOp], dtype: DType=default_float) -> Tensor: return Tensor(op=ir.Allocate.new(shape, dtype)) # type: ignore @classmethod - def const(cls, obj: Any, dtype: DType=index): + def const(cls, obj: Any, dtype: DType=index) -> ir.Const: match obj: case int(): assert dtype in integers case float(): assert dtype in floats case _: raise TypeError(f"ATen.const: Only integer or float objects can become constant! getting {obj}") return ir.Const.new(obj, dtype) - def forward(self, op: Callable, *args: List, **kwargs) -> ATen: return Tensor(op=op(*args, **kwargs)) + def forward(self, op: Callable, *args: List, **kwargs) -> Tensor: return Tensor(op=op(*args, **kwargs)) # type: ignore def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> ATenSpec: return ATenSpec(item) def __repr__(self) -> str: # TODO: Display Shape, realized buffer, etc. return f"{self.__class__.__name__}<{self.op}>" @property - def dtype(self): return self.op.T.dtype + def dtype(self) -> DType: + assert self.op.T is not None + return self.op.T.dtype @staticmethod - def wrap_const(obj: Any, dtype: DType = index): + def wrap_const(obj: Union[ATen, ir.ATenOp, float, int], dtype: DType = index) -> ir.ATenOp: """ Ensures obj is a constant of dtype """ if isinstance(obj, ATen): assert obj.dtype == dtype # todo: decent error msg + return obj.op + elif isinstance(obj, ir.ATenOp): + assert obj.T is not None and obj.T.dtype == dtype # todo: decent error msg return obj else: return ATen.const(obj, dtype=dtype) @staticmethod - def top(f: Callable[Any, ATen]): + def top(f: Callable) -> Callable: """ Declares the given function as toplevel tensor operation. """ @@ -59,9 +64,9 @@ def top(f: Callable[Any, ATen]): ## movement ops mixin class ATenMovements(): @property - def shape(self) -> List[ATen]: return [x.size for x in self.op.T.axes] + def shape(self) -> tuple[ir.ATenOp, ...]: return tuple([x.size for x in self.op.T.axes]) # type: ignore @property - def strides(self) -> List[ATen]: return [x.stride for x in self.op.T.axes] + def strides(self) -> tuple[ir.ATenOp, ...]: return tuple([x.stride for x in self.op.T.axes]) # type: ignore @property def ndim(self) -> int: return len(self.shape) def _resolve_dim(self, dim: int, *, extra: bool = False) -> int: @@ -70,7 +75,7 @@ def _resolve_dim(self, dim: int, *, extra: bool = False) -> int: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total) - 1]}") return dim + total if dim < 0 else dim # ref: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/mixin/movement.py#L58 - def _broadcast_to(self, new_shape: List[ATen]) -> Self: + def _broadcast_to(self, new_shape: tuple[ir.ATenOp, ...]) -> Self: """ Implements Numpy-Semantic Broadcasting operation """ @@ -81,8 +86,8 @@ def _broadcast_to(self, new_shape: List[ATen]) -> Self: if not all(ir.ATenOp.eql(s, ns) or ir.ATenOp.eql(s, 1) for s, ns in zip(shape, new_shape, strict=True)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") reshaped = self.reshape(shape) - ret = Tensor(op=ir.View.expand(self.op, new_shape)) - return reshaped if ir.ATenOp.equals(ret.shape, reshaped.shape) else ret + ret = Tensor(op=ir.View.expand(self.op, new_shape)) # type: ignore + return reshaped if ir.ATenOp.equals(ret.shape, reshaped.shape) else ret # type: ignore @ATen.top def reshape(self, shape, *args) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) From 8c5e351de6445b639f9cd5bf7f1beba58a89677f Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:54:20 +0900 Subject: [PATCH 51/53] helpers --- caten/ir.py | 6 ++-- caten/tensor.py | 81 +++++++++++++++++++++++++------------------------ 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/caten/ir.py b/caten/ir.py index ca7b4195..2521e4cf 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -24,7 +24,7 @@ def _freeze(x: Any) -> Any: return tuple(sorted((k, ATenOpMetaclass._freeze(v)) for k, v in x.items())) return x def __call__(cls, args: tuple[ATenOp, ...] | list[ATenOp], T: "ATenOpType | None" = None, **kwargs: Any) -> ATenOp: - T = cls.verify(tuple(args), T, **kwargs) # run type inference+verification + T = cls.verify(tuple(args), T, **kwargs) # type: ignore wret = ATenOpMetaclass.cache.get(key:=(cls, tuple(args), ATenOpMetaclass._freeze(T), ATenOpMetaclass._freeze(kwargs)), None) if wret is not None and (ret:=wret()) is not None: return ret.simplify() ATenOpMetaclass.cache[key] = weakref.ref(created:=super().__call__(tuple(args), T=T, **kwargs)) @@ -52,7 +52,7 @@ class ATenOpType(): is_ptr: bool = False # TODO: for vectorize? def index(self, indices: List[ATenOp]) -> Any: assert self.ndim == len(indices) - total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes, strict=True)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) + total = itertools.accumulate([b.index(a) for (a, b) in zip(indices, self.axes, strict=True)], lambda a, b: Add((a, b)), initial=Const.new(0, index)) # type: ignore if self.offset: total = Add((total, self.offset)) # type: ignore return total @property @@ -62,7 +62,7 @@ def from_shape(shape: List[Any], dtype: DType) -> ATenOpType: def _mul(a: Any, b: Any) -> Any: return Mul((_const(a), _const(b))) strides = tuple(itertools.accumulate(reversed(shape[1:]), _mul, initial=_const(1)))[::-1] return ATenOpType( - axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides, strict=True)]), + axes=tuple([ATenAxis(size=_const(size), stride=_const(stride), offset=_const(0), incf=_const(1)) for (size, stride) in zip(shape, strides, strict=True)]), # type: ignore dtype=dtype, ) diff --git a/caten/tensor.py b/caten/tensor.py index 29182435..c3144bbd 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -89,72 +89,75 @@ def _broadcast_to(self, new_shape: tuple[ir.ATenOp, ...]) -> Self: ret = Tensor(op=ir.View.expand(self.op, new_shape)) # type: ignore return reshaped if ir.ATenOp.equals(ret.shape, reshaped.shape) else ret # type: ignore @ATen.top - def reshape(self, shape, *args) -> Self: + def reshape(self, shape: tuple[Union[int, ir.ATenOp], ...], *args: Any) -> Self: new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))]) if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") - if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if ir.ATenOp.eql(s, -1) else s for s in new_shape]) + if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if ir.ATenOp.eql(s, -1) else s for s in new_shape]) # type: ignore if not ir.ATenOp.eql(prod(self.shape), prod(new_shape)): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") - ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) - return self if ir.ATenOp.equals(ret.shape, self.shape) else ret + ret = Tensor(op=ir.View.reshape(self.op, [ATen.wrap_const(s, dtype=index) for s in new_shape])) # type: ignore + return self if ir.ATenOp.equals(ret.shape, self.shape) else ret # type: ignore @ATen.top def shrink(self, arg: tuple[tuple[int, int] | None, ...]) -> Self: raise NotImplementedError("shrink todo") @ATen.top - def permute(self, order, *args) -> Self: + def permute(self, order: tuple[int, ...], *args: Any) -> Self: order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") - return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self + return Tensor(op=ir.View.permute(self.op, order_arg)) if order_arg != tuple(range(self.ndim)) else self # type: ignore @ATen.top - def expand(self, shape, *args) -> Self: + def expand(self, shape: tuple[Union[int, ir.ATenOp], ...], *args: Any) -> Self: new_shape = tuple(from_ if ir.ATenOp.eql(to, -1) or to is None else to for from_, to in zip(*(align_left(self.shape, argfix(shape, *args))), strict=True)) - return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) + return self._broadcast_to([ATen.wrap_const(s, dtype=index) for s in new_shape]) # type: ignore ## arithmetic mixin class ATenArith(): def _broadcasted(self, y:Tensor|int|float, reverse:bool=False) -> tuple[Tensor, Tensor]: - x: ATen = self + x = self + assert isinstance(x, Tensor) if not isinstance(y, Tensor): - y = Tensor.const(y, dtype=x.dtype) - if x.dtype != y.dtype: + y = Tensor.const(y, dtype=x.dtype) # type: ignore + if x.dtype != y.dtype: # type: ignore raise TypeError("Cannot add x and y (dtypes mismatch, todo)") - if reverse: x, y = y, x + if reverse: x, y = y, x # type: ignore # compute the output shape - def _broadcast_shape(*shapes:tuple[int, ...]) -> tuple[int, ...]: - def smax(a, b): + def _broadcast_shape(*shapes:tuple[int|ir.ATenOp, ...]) -> tuple[int|ir.ATenOp, ...]: + def smax(a: int|ir.ATenOp, b: int|ir.ATenOp) -> int|ir.ATenOp: if ir.ATenOp.eql(a, 1): return b elif ir.ATenOp.eql(b, 1): return a else: assert ir.ATenOp.eql(a, b) return a # a != b is asserted here? - return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*align_left(*shapes), strict=True)) + return tuple(0 if 0 in nth_dim_sizes else smax(*nth_dim_sizes) for nth_dim_sizes in zip(*align_left(*shapes), strict=True)) + assert isinstance(x, Tensor) and isinstance(y, Tensor) out_shape = _broadcast_shape(x.shape, y.shape) - return x._broadcast_to(out_shape), y._broadcast_to(out_shape) + return x._broadcast_to(out_shape), y._broadcast_to(out_shape) # type: ignore # TODO: # - reduce option # - ir.Add.new (or binop) can have reduce option @ATen.top - def add(self, other, reverse:bool=False): return self.forward(ir.Add, tuple(self._broadcasted(self, other, reverse=reverse))) + def add(self, other, reverse:bool=False) -> Self: return self.forward(ir.Add, tuple(self._broadcasted(self, other, reverse=reverse))) # type: ignore @ATen.top - def mul(self, other, reverse:bool=False): return self.forward(ir.Mul, tuple(self._broadcasted(self, other, reverse=reverse))) - def __eq__(self, other: Any): - print("A") - pass - def __neq__(self, other: Any): - print("B") - pass - def __add__(self, other: Any): return self.add(other) - def __radd__(self, other: Any): return self.add(other, reverse=True) - def __mul__(self, other: Any): return self.mul(other) - def __rmul__(self, other: Any): return self.mul(other, reverse=True) + def mul(self, other, reverse:bool=False) -> Self: return self.forward(ir.Mul, tuple(self._broadcasted(self, other, reverse=reverse))) # type: ignore + @ATen.top + def idiv(self, other, reverse:bool=False) -> Self: return self.forward(ir.IDiv, tuple(self._broadcasted(self, other, reverse=reverse))) # type: ignore + + #def __eq__(self, other: Any): pass + # def __neq__(self, other: Any): pass + def __add__(self, other: Any) -> Self: return self.add(other) # type: ignore + def __radd__(self, other: Any) -> Self: return self.add(other, reverse=True) # type: ignore + def __mul__(self, other: Any) -> Self: return self.mul(other) # type: ignore + def __rmul__(self, other: Any) -> Self: return self.mul(other, reverse=True) # type: ignore + def __floordiv__(self, other: Any) -> Self: return self.idiv(other) # type: ignore + ## math mixin class ATenMath(): @ATen.top - def sin(self: ATen): return self.forward(ir.Sin, self) + def sin(self: ATen): return self.forward(ir.Sin, self) # type: ignore @ATen.top - def cos(self: ATen): return self.forward(ir.Sin, self + Tensor.const(0.0, dtype=self.dtype)) + def cos(self: ATen): return self.forward(ir.Sin, self + Tensor.const(0.0, dtype=self.dtype)) # type: ignore ## nn ops mixin class ATenNN(): pass @@ -168,40 +171,40 @@ class Facet(): pass ## abstraction over backends class ATenBase(ATen, ATenMovements, ATenArith,ATenMath, ATenNN, ATenLinalg, metaclass=ABCMeta): - def __init__(self, *args, op=None): - self.op = op + def __init__(self, *args: Any, op: Union[None, ir.ATenOp]=None): + if op is not None: self.op = op ## == AbstractionLayer @staticmethod - def register(device_id: str, cls: ClassVar): + def register(device_id: str, cls: Any) -> None: DEVICE_TO_TENSOR[device_id] = cls @abstractmethod - def allocate(self): + def allocate(self) -> None: pass @abstractmethod - def free(self): + def free(self) -> None: pass @abstractmethod - def compile(self): + def compile(self) -> None: pass @staticmethod @abstractmethod - def render(op): + def render(op: Any) -> None: pass class Tensor(ATenBase): - def __new__(cls, *args, **kwargs): + def __new__(cls: Any, *args: Any, **kwargs: Any) -> Any: impl = DEVICE_TO_TENSOR.get(get_backend()) if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}") return impl(*args, **kwargs) ## == [Loop-For Style Frontend IR Specs] ====================================== def kernel(get_kernel: bool = False) -> Callable: def decorator(func: Callable) -> Callable: - pass + return func return decorator # how to generate polyhedral model from tensor ops? # rangeify -> range/when ==> polyhedral model From 8730f7c89cbf6789b0e00305d77ea544626cedd5 Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:54:38 +0900 Subject: [PATCH 52/53] helpers --- caten/dtype.py | 3 +-- caten/helpers.py | 2 +- caten/ir.py | 2 +- caten/tensor.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/caten/dtype.py b/caten/dtype.py index b768dca2..9072b046 100644 --- a/caten/dtype.py +++ b/caten/dtype.py @@ -1,10 +1,9 @@ from __future__ import annotations from dataclasses import dataclass - - from typing import Any + class DTypeMetaClass(type): dcache: dict[tuple, DType] = {} def __call__(cls, *args: Any, **kwargs: Any) -> DType: diff --git a/caten/helpers.py b/caten/helpers.py index cfb41c35..5cecc491 100644 --- a/caten/helpers.py +++ b/caten/helpers.py @@ -2,7 +2,7 @@ import functools import operator -from typing import Iterable, TypeVar, Any +from typing import Any, Iterable, TypeVar T = TypeVar("T") diff --git a/caten/ir.py b/caten/ir.py index 2521e4cf..e0d1dfbd 100644 --- a/caten/ir.py +++ b/caten/ir.py @@ -6,7 +6,7 @@ import operator import weakref from dataclasses import dataclass -from typing import Any, Dict, List, Union, Sequence +from typing import Any, Dict, List, Union from .dtype import DType, index diff --git a/caten/tensor.py b/caten/tensor.py index c3144bbd..3ea242af 100644 --- a/caten/tensor.py +++ b/caten/tensor.py @@ -2,7 +2,7 @@ import os from abc import ABCMeta, abstractmethod -from typing import Any, Callable, ClassVar, List, Self, Tuple, Union +from typing import Any, Callable, List, Self, Tuple, Union import caten.ir as ir from caten.helpers import align_left, argfix, prod From 1cb1d463c8a79af68ec5be7b4a4aa18ecdcc3dff Mon Sep 17 00:00:00 2001 From: hikettei Date: Thu, 25 Dec 2025 22:55:57 +0900 Subject: [PATCH 53/53] helpers --- caten/runtime/cpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/caten/runtime/cpu.py b/caten/runtime/cpu.py index c6130328..1b544b15 100644 --- a/caten/runtime/cpu.py +++ b/caten/runtime/cpu.py @@ -1,4 +1,5 @@ from typing import Any + import caten as C