Merged

Commits (53)
101b88b  wip  (hikettei, Dec 24, 2025)
9578c37  concepts  (hikettei, Dec 24, 2025)
fec84c0  wip  (hikettei, Dec 24, 2025)
7856871  wip  (hikettei, Dec 24, 2025)
2d4c0ce  wip  (hikettei, Dec 24, 2025)
aa350ec  facet ideas  (hikettei, Dec 24, 2025)
d800a8e  allocation graph  (hikettei, Dec 24, 2025)
51de548  allocation graph  (hikettei, Dec 24, 2025)
557dc10  allocation graph  (hikettei, Dec 24, 2025)
2ab506b  allocation graph  (hikettei, Dec 24, 2025)
673906c  concrete semantics  (hikettei, Dec 24, 2025)
e4ddd5e  dataclass  (hikettei, Dec 24, 2025)
3ac5d96  dataclass  (hikettei, Dec 24, 2025)
e66e626  dataclass  (hikettei, Dec 24, 2025)
e4bacfd  dataclass  (hikettei, Dec 24, 2025)
a3056eb  tuple  (hikettei, Dec 24, 2025)
e259924  concrete semantics  (hikettei, Dec 25, 2025)
227b779  concrete semantics  (hikettei, Dec 25, 2025)
19db18b  concrete semantics  (hikettei, Dec 25, 2025)
74471c9  WIP  (hikettei, Dec 25, 2025)
a5c0b82  some clean ups  (hikettei, Dec 25, 2025)
7e60cc0  just an idea  (hikettei, Dec 25, 2025)
71fc714  feat: reshape  (hikettei, Dec 25, 2025)
704b6be  feat: reshape  (hikettei, Dec 25, 2025)
0b526ce  feat: expand  (hikettei, Dec 25, 2025)
62de279  feat: expand  (hikettei, Dec 25, 2025)
f679b50  Cache, Type Inference, Verification  (hikettei, Dec 25, 2025)
bc0f4a1  Cache, Type Inference, Verification  (hikettei, Dec 25, 2025)
7470551  Cache, Type Inference, Verification  (hikettei, Dec 25, 2025)
d66007c  Cache, Type Inference, Verification  (hikettei, Dec 25, 2025)
8d68192  Feat: Simplifier  (hikettei, Dec 25, 2025)
5820b67  Feat: Simplifier  (hikettei, Dec 25, 2025)
fa3ae24  Feat: Simplifier  (hikettei, Dec 25, 2025)
d43c14b  Reshape symbolic  (hikettei, Dec 25, 2025)
6437e69  docs  (hikettei, Dec 25, 2025)
f35442a  docs  (hikettei, Dec 25, 2025)
fc442a6  smax  (hikettei, Dec 25, 2025)
3061ab2  write specs on def  (hikettei, Dec 25, 2025)
d6e83da  write specs on def  (hikettei, Dec 25, 2025)
10b0a41  Fix scope issue  (hikettei, Dec 25, 2025)
b0a7749  Fix scope issue  (hikettei, Dec 25, 2025)
c34c90a  fix ruff  (hikettei, Dec 25, 2025)
51abfad  fix ruff  (hikettei, Dec 25, 2025)
6215735  fix ruff  (hikettei, Dec 25, 2025)
e72883f  helpers  (hikettei, Dec 25, 2025)
50372b0  helpers  (hikettei, Dec 25, 2025)
77c9157  helpers  (hikettei, Dec 25, 2025)
4700274  helpers  (hikettei, Dec 25, 2025)
a2ce359  helpers  (hikettei, Dec 25, 2025)
dec32ae  helpers  (hikettei, Dec 25, 2025)
8c5e351  helpers  (hikettei, Dec 25, 2025)
8730f7c  helpers  (hikettei, Dec 25, 2025)
1cb1d46  helpers  (hikettei, Dec 25, 2025)
10 changes: 10 additions & 0 deletions caten/__init__.py
@@ -0,0 +1,10 @@
from . import dtype, helpers, ir, tensor
from .tensor import ATenSpec, ATen, ATenMath, ATenMovements, ATenNN, ATenLinalg, ATenBase, get_backend, Tensor
from .runtime import cpu

__all__ = [
"dtype",
"helpers",
"ir",
"tensor"
]


medium

The __all__ list is missing key public APIs such as Tensor. Exporting the main classes and functions that users are expected to use directly makes them discoverable and defines the package's public surface for from caten import *.

Suggested change
- __all__ = [
- "dtype",
- "helpers",
- "ir",
- "tensor"
- ]
+ __all__ = [
+ "dtype",
+ "helpers",
+ "ir",
+ "tensor",
+ "Tensor",
+ "get_backend",
+ "ATenSpec"
+ ]
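To illustrate the point above, a minimal standalone sketch (using a throwaway module named `demo`, not the caten package itself) of how `__all__` gates star-imports:

```python
import types

# Build a throwaway module to show how __all__ controls `from demo import *`.
mod = types.ModuleType("demo")
mod.Tensor = type("Tensor", (), {})   # hypothetical public class
mod.helper = lambda: None             # hypothetical internal helper
mod.__all__ = ["Tensor"]              # only Tensor is exported on star-import

# `from demo import *` copies exactly the names listed in __all__:
exported = {name: getattr(mod, name) for name in mod.__all__}
assert "Tensor" in exported and "helper" not in exported
```

Note that an explicit `from demo import helper` would still work; `__all__` only affects star-imports and discoverability.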

25 changes: 25 additions & 0 deletions caten/dtype.py
@@ -0,0 +1,25 @@
from __future__ import annotations
from dataclasses import dataclass, fields


medium

The fields function is imported from dataclasses but is not used in this file. It's good practice to remove unused imports to keep the code clean.

Suggested change
- from dataclasses import dataclass, fields
+ from dataclasses import dataclass


class DTypeMetaClass(type):
dcache: dict[tuple, DType] = {}
def __call__(cls, *args, **kwargs):
if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
return ret

# TODO: Vector/Packed DType
@dataclass(frozen=True, eq=False)
class DType:
name: str
@staticmethod
def new(name:str): return DType(name)


high

The DType class does not use the DTypeMetaClass you've defined, which means the caching mechanism is not active. This results in multiple instances for the same dtype, e.g., DType.new('float32') is not DType.new('float32'). To enable caching and ensure dtype singletons, you should apply the metaclass to DType and remove the redundant new static method. Also, with frozen=True, you should let eq be True (the default) to allow for value-based comparison, which will work correctly with the singleton pattern provided by the metaclass.

Suggested change
- @dataclass(frozen=True, eq=False)
- class DType:
- name: str
- @staticmethod
- def new(name:str): return DType(name)
+ @dataclass(frozen=True)
+ class DType(metaclass=DTypeMetaClass):
+ name: str
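A quick sketch of the suggested pattern, reproducing the metaclass from the diff in isolation, showing that the cache makes equal-argument constructions return the same object:

```python
from dataclasses import dataclass

# Metaclass as in the diff: __call__ intercepts construction and caches
# one instance per argument tuple.
class DTypeMetaClass(type):
    dcache: dict = {}
    def __call__(cls, *args, **kwargs):
        if (ret := DTypeMetaClass.dcache.get(args)) is not None:
            return ret
        DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
        return ret

@dataclass(frozen=True)
class DType(metaclass=DTypeMetaClass):
    name: str

# With the metaclass applied, equal arguments yield the same singleton:
assert DType("float32") is DType("float32")
assert DType("float32") is not DType("int32")
```

With singletons guaranteed by the cache, identity (`is`) and equality (`==`) agree, which is why dropping `eq=False` is safe.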


## definitions
float64 = DType.new("float64")
float32 = DType.new("float32")
int64 = DType.new("int64")
int32 = DType.new("int32")
Comment on lines 22 to 27


high

Following the change to use the metaclass for DType, you should now instantiate dtypes by calling DType(...) directly, which will leverage the caching mechanism.

Suggested change
- float64 = DType.new("float64")
- float32 = DType.new("float32")
- int64 = DType.new("int64")
- int32 = DType.new("int32")
+ float64 = DType("float64")
+ float32 = DType("float32")
+ int64 = DType("int64")
+ int32 = DType("int32")


## dtype aliases
index = int64
Empty file added caten/helpers.py
Empty file.
167 changes: 167 additions & 0 deletions caten/ir.py
@@ -0,0 +1,167 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import List, Dict, Any
from dataclasses import dataclass
from .dtype import DType

@dataclass(frozen=True)
class ATenAxis():
shape: ATenOp
stride: ATenOp
offset: ATenOp
incf: ATenOp
def index(self, i: ATenOp):
# TODO: Assert i.T.dtype is dtype.index
return Mul(self.stride, Add(Mul(i, self.incf), self.offset))

@dataclass(frozen=True)
class ATenOpType():
shape: List[ATenAxis]
dtype: DType
offset: ATenOp

@dataclass(frozen=True)
class ATenOp(metaclass=ABCMeta):
args: List[AtenOp]


critical

There's a typo in the type hint. AtenOp should be ATenOp to match the class definition.

Suggested change
- args: List[AtenOp]
+ args: List[ATenOp]

T: ATenOpType
@classmethod
@abstractmethod
def from_astexpr(cls):
pass

@abstractmethod
def infer_dtype(self):
pass


critical

ATenOp is defined as an abstract base class with abstract methods from_astexpr and infer_dtype. However, none of the concrete subclasses (e.g., Neg, Add, Mul) implement these methods. This will cause a TypeError when you try to instantiate any of these op classes.

You need to either:

  1. Provide concrete implementations for from_astexpr and infer_dtype in each subclass.
  2. Remove the @abstractmethod decorators if they are not meant to be abstract yet.
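The failure mode can be reproduced in isolation; `Op`, `Neg`, and `NegOk` below are hypothetical stand-ins, not the classes from the diff:

```python
from abc import ABCMeta, abstractmethod

class Op(metaclass=ABCMeta):
    @abstractmethod
    def infer_dtype(self): ...

class Neg(Op):  # leaves infer_dtype unimplemented
    pass

# Instantiating a subclass with a remaining abstract method raises TypeError:
try:
    Neg()
    raised = False
except TypeError:
    raised = True
assert raised

class NegOk(Op):  # providing a concrete implementation satisfies the ABC
    def infer_dtype(self):
        return "float32"

assert NegOk().infer_dtype() == "float32"
```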

## == Tensor Graph ============================================================
class UnaryOps():
def verify(self): verify_tensor_op(self, 1)
class BinaryOps():
def verify(self): verify_tensor_op(self, 2)
class TernaryOps():
def verify(self): verify_tensor_op(self, 3)


critical

The function verify_tensor_op used in UnaryOps, BinaryOps, and TernaryOps is not defined. This will raise a NameError when any verify() method is called. You should either define this function or remove/comment out these methods if they are placeholders.

### UnaryOps
class Neg(ATenOp, UnaryOps):
"""
OUT = -X
"""
pass

class Recip(ATenOp, UnaryOps):
pass

class Sin(ATenOp, UnaryOps):
pass

class Exp2(ATenOp, UnaryOps):
pass

class Log2(ATenOp, UnaryOps):
pass

class Sqrt(ATenOp, UnaryOps):
pass

class Cast(ATenOp, UnaryOps):
pass

class Bitcast(ATenOp, UnaryOps):
pass

class Not(ATenOp, UnaryOps):
"""
Logical NOT if X is a boolean,
otherwise bitwise NOT (~x)
"""
pass
### BinaryOps
class Add(ATenOp, BinaryOps):
"""
OUT = Add(X, Y)
"""
@classmethod
def from_ast_expr(cls):
pass

class Mul(ATenOp, BinaryOps):
"""
OUT = Mul(X, Y)
"""
@classmethod
def from_ast_expr(cls):
pass

class IDiv(ATenOp, BinaryOps):
pass

class And(ATenOp, BinaryOps):
pass

class Or(ATenOp, BinaryOps):
pass

class And(ATenOp, BinaryOps):
pass


critical

The class And is defined twice in this file (the first definition is on line 97). This will overwrite the first definition. Please remove the duplicate.


class Xor(ATenOp, BinaryOps):
pass

class Max(ATenOp, BinaryOps):
pass

class Mod(ATenOp, BinaryOps):
pass

class Neq(ATenOp, BinaryOps):
pass

class Lt(ATenOp, BinaryOps):
pass
### TernaryOps
class Where(ATenOp, TernaryOps):
pass

### Allocation
class Variable(ATenOp):
symbol: str


high

Variable inherits from ATenOp, which is a dataclass, but Variable itself is not. This means symbol is just a class annotation and not an instance field. To make symbol an instance field, you should also make Variable a dataclass.

Suggested change
- class Variable(ATenOp):
- symbol: str
+ @dataclass(frozen=True)
+ class Variable(ATenOp):
+ symbol: str
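The field-inheritance behavior this comment describes can be checked directly; `Base` and `PlainVariable` below are illustrative names, not part of the PR:

```python
from dataclasses import dataclass, fields

@dataclass(frozen=True)
class Base:
    args: list

# Without @dataclass, `symbol` stays a class-level annotation and is NOT
# registered as a field; only the inherited field `args` exists.
class PlainVariable(Base):
    symbol: str

assert [f.name for f in fields(PlainVariable)] == ["args"]

# With @dataclass, `symbol` becomes a real instance field with its own
# __init__ parameter.
@dataclass(frozen=True)
class Variable(Base):
    symbol: str

v = Variable(args=[], symbol="x")
assert [f.name for f in fields(Variable)] == ["args", "symbol"]
assert v.symbol == "x"
```

Note that since the base is declared with `frozen=True`, the dataclass subclass must be frozen as well.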


class Allocate(ATenOp):
"""
Allocate(S1, S2, S3, ...)
"""
pass

## == JIT =====================================================================
class Reduce(ATenOp):
"""
OUT = Reduce(A, B, op=BinaryOps)
"""
op: BinaryOps


high

Reduce inherits from the ATenOp dataclass but is not a dataclass itself. This means op is only a class annotation, not an instance field. To fix this, you should make Reduce a dataclass.

Suggested change
- class Reduce(ATenOp):
- """
- OUT = Reduce(A, B, op=BinaryOps)
- """
- op: BinaryOps
+ @dataclass(frozen=True)
+ class Reduce(ATenOp):
+ """
+ OUT = Reduce(A, B, op=BinaryOps)
+ """
+ op: BinaryOps

@classmethod
def from_ast_expr(cls):
pass

class Store(ATenOp):
pass
## ControlFlow
class Range(ATenOp):
pass

class Loop(ATenOp):
pass

class When(ATenOp):
pass

class Progn(ATenOp):
pass
## == ScheduleOps ============================================================
class Polyhedral(ATenOp):
pass

def Var():
pass


medium

The function Var is defined but has an empty body. This should be implemented or removed if it's a placeholder.


# e.g.:
# a = T.Var("A[m n]", float32)
# P.stmt("...")[a]
7 changes: 0 additions & 7 deletions caten/ops.py

This file was deleted.

13 changes: 13 additions & 0 deletions caten/runtime/cpu.py
@@ -0,0 +1,13 @@
import caten as C

class CPUTensor(C.ATenBase):
def allocate(self):
pass

def free(self):
pass

def compile(self):
pass

C.ATenBase.register("CPU", CPUTensor)
89 changes: 89 additions & 0 deletions caten/tensor.py
@@ -0,0 +1,89 @@
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from typing import Any, Optional, Tuple, Union, ClassVar
import os


critical

Some necessary imports are missing. List and Callable should be imported from typing, and dataclass should be imported from dataclasses to be used as a decorator.

Suggested change
- from typing import Any, Optional, Tuple, Union, ClassVar
- import os
+ from typing import Any, Optional, Tuple, Union, ClassVar, List, Callable
+ from dataclasses import dataclass
+ import os

import caten.ir as ir
# [TODO]
# - Tensor => Fused Tensor Graph Construction
# - Tensor Kernel Construction
# - Then, start working on auto scheduler
## Backend Abstraction
DEVICE_TO_TENSOR = {}
def get_backend(): return os.environ.get("BACKEND", "CPU")
##

class ATenSpec:
"""
Usage: C.Tensor[float32, "M", "N"] -> TensorSpec(M N)
"""
def __init__(self, shape: Tuple[Any, ...], dtype: Any = None):
self.shape = shape
self.dtype = dtype
def __repr__(self) -> str: return f"TensorSpec({self.shape}, {self.dtype})"

class ATen:
op: ATenOp # ATen is just a wrapper for ATenOp


critical

ATen is instantiated as if it were a dataclass in the apply method, but it's not defined as one. This will cause a TypeError. You should decorate it with @dataclass. Additionally, ATenOp is not defined in this scope and should be referenced as ir.ATenOp.

Suggested change
- class ATen:
- op: ATenOp # ATen is just a wrapper for ATenOp
+ @dataclass(frozen=True)
+ class ATen:
+ op: ir.ATenOp # ATen is just a wrapper for ATenOp

@classmethod
def from_shape(cls, shape: List[ATenOp]):


critical

ATenOp is not defined in this scope. It should be referenced from the ir module as ir.ATenOp.

Suggested change
- def from_shape(cls, shape: List[ATenOp]):
+ def from_shape(cls, shape: List[ir.ATenOp]):

return ir.Allocate(shape) # TODO

def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs))


critical

The annotation *args: List is misleading: it declares that each individual positional argument is a List, not that the arguments are collected into one. To type variadic positional arguments, use *args: Any. It's also good practice to type **kwargs as **kwargs: Any.

Suggested change
- def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs))
+ def apply(self, op: Callable, *args: Any, **kwargs: Any) -> ATen: return ATen(op=op(*args, **kwargs))


def __class_getitem__(cls, item: Union[Any, Tuple[Any, ...]]) -> TensorSpec:
# Usage: C.Tensor[10, 10] -> TensorSpec((10, 10))
# TODO
pass

def polyhedral(self):
pass

class ATenMath():
pass

class ATenMovements():
pass

class ATenNN():
pass

class ATenLinalg():
pass

class ATenBase(ATen, ATenMath, ATenNN, ATenMovements, ATenLinalg, metaclass=ABCMeta):
## == AbstractionLayer
@staticmethod
def register(device_id: str, cls: ClassVar):
DEVICE_TO_TENSOR[device_id] = cls

@abstractmethod
def allocate(self):
pass

@abstractmethod
def free(self):
pass

@abstractmethod
def compile(self):
pass

class Tensor(ATenBase):
def __new__(cls, *args, **kwargs):
impl = DEVICE_TO_TENSOR.get(get_backend())
if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}")
Comment on lines +201 to +202


medium

For efficiency and clarity, you can store the result of get_backend() in a variable to avoid calling it twice.

Suggested change
- impl = DEVICE_TO_TENSOR.get(get_backend())
- if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}")
+ backend = get_backend()
+ impl = DEVICE_TO_TENSOR.get(backend)
+ if impl is None: raise ValueError(f"Unknown BACKEND={backend}")

return impl(*args, **kwargs)
## For-Style Graph Construction
def kernel(get_kernel: bool = False) -> Callable:
def decorator(func: Callable) -> Callable:
pass
return decorator

# how to generate polyhedral model from tensor ops?
# rangeify -> range/when ==> polyhedral model
# with C.range(10, 10):
# with C.when(10, 10)
class range():
pass

class when():
pass


medium

According to PEP 8, class names should use PascalCase. Please rename range to Range and when to When. Renaming range also avoids shadowing the built-in range function, which can be a source of confusion.

Suggested change
- class range():
- pass
- class when():
- pass
+ class Range():
+ pass
+ class When():
+ pass

4 changes: 4 additions & 0 deletions test/test_kernel.py
@@ -0,0 +1,4 @@
import caten as C

def test_tensor():
print(C.Tensor())