Skip to content
33 changes: 33 additions & 0 deletions caten/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .kernel import (
DType,
Tensor,
TensorSpec,
f32,
float32,
i32,
int32,
kernel,
parallel,
range,
unroll,
vars,
vectorize,
when,
)

# Public API of the caten package; kept alphabetically sorted so additions
# and removals are easy to spot in diffs.
__all__ = [
    "DType",
    "Tensor",
    "TensorSpec",
    "f32",
    "float32",
    "i32",
    "int32",
    "kernel",
    "parallel",
    "range",
    "unroll",
    "vars",
    "vectorize",
    "when",
]
Comment on lines +18 to +33

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __all__ list is not alphabetically sorted. Sorting it improves readability and maintainability, making it easier to see what is exported from the module.

Suggested change
__all__ = [
"vars",
"range",
"kernel",
"Tensor",
"TensorSpec",
"float32",
"int32",
"f32",
"i32",
"DType",
"when",
"parallel",
"vectorize",
"unroll",
]
__all__ = [
"DType",
"Tensor",
"TensorSpec",
"f32",
"float32",
"i32",
"int32",
"kernel",
"parallel",
"range",
"unroll",
"vars",
"vectorize",
"when",
]

Comment on lines +18 to +33

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better maintainability and easier navigation, it's a good practice to keep the __all__ list sorted alphabetically.

Suggested change
__all__ = [
"vars",
"range",
"kernel",
"Tensor",
"TensorSpec",
"float32",
"int32",
"f32",
"i32",
"DType",
"when",
"parallel",
"vectorize",
"unroll",
]
__all__ = [
"DType",
"Tensor",
"TensorSpec",
"f32",
"float32",
"i32",
"int32",
"kernel",
"parallel",
"range",
"unroll",
"vars",
"vectorize",
"when",
]

38 changes: 35 additions & 3 deletions caten/isl/specs/ast_node_list.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from ctypes import c_int
from typing import TYPE_CHECKING, Any

from ..ffi import load_libisl
from ..func import ISLFunction
from ..mixin import ISLObjectMixin
from ..obj import ISLObject
from ..qualifier import Give, Take
from ..qualifier import Give, Keep, Param, Take
from ..registry import register_type

if TYPE_CHECKING:
Expand All @@ -17,23 +18,54 @@
class AstNodeList(ISLObject, ISLObjectMixin):
    """Python wrapper around an isl ``isl_ast_node_list`` handle."""

    __slots__ = ()

    def __init__(self, handle_or_spec: Any = None) -> None:
        # Handle ownership/registration is delegated entirely to ISLObject.
        super().__init__(handle_or_spec)

    def copy_handle(self) -> Any:
        # Copying is deliberately unsupported for this wrapper type.
        raise NotImplementedError(f"{type(self).__name__} does not support copy.")

    @classmethod
    def free_handle(cls, handle: Any) -> None:
        # Release the raw isl handle via the FFI library directly.
        _lib.isl_ast_node_list_free(handle)

    @classmethod
    def from_node(cls, node: "ASTNode") -> "AstNodeList":
        """Build a list from a single AST node (isl_ast_node_list_from_ast_node)."""
        return _isl_ast_node_list_from_ast_node(node)

    def n_ast_node(self) -> int:
        """Return the number of AST nodes contained in this list."""
        return _isl_ast_node_list_n_ast_node(self)

    def get_ast_node(self, index: int) -> "ASTNode":
        """Return the node at *index*; the binding's Give return means the
        caller owns the returned handle."""
        return _isl_ast_node_list_get_ast_node(self, index)


# Expose the wrapper in the ISL type registry so qualifier specs below can
# refer to it by name.
register_type("AstNodeList", AstNodeList)

# Low-level FFI bindings. The qualifiers mirror isl's ownership annotations:
# Take consumes the argument handle, Keep borrows it, Give transfers
# ownership of the result to the caller; Param marks a plain C value.
# NOTE(review): _isl_ast_node_list_free is declared but free_handle above
# calls _lib directly -- presumably intentional, verify against sibling specs.
_isl_ast_node_list_free = ISLFunction.create(
    "isl_ast_node_list_free",
    Take("AstNodeList"),
    return_=Give("AstNodeList"),
    lib=_lib,
)

_isl_ast_node_list_from_ast_node = ISLFunction.create(
    "isl_ast_node_list_from_ast_node",
    Take("ASTNode"),
    return_=Give("AstNodeList"),
    lib=_lib,
)

_isl_ast_node_list_n_ast_node = ISLFunction.create(
    "isl_ast_node_list_n_ast_node",
    Keep("AstNodeList"),
    return_=Param(int, ctype=c_int),
    lib=_lib,
)

_isl_ast_node_list_get_ast_node = ISLFunction.create(
    "isl_ast_node_list_get_ast_node",
    Keep("AstNodeList"),
    Param(int, ctype=c_int),
    return_=Give("ASTNode"),
    lib=_lib,
)
175 changes: 175 additions & 0 deletions caten/kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

import inspect
import os
from functools import wraps
from typing import Any, Callable, List, Tuple, Union

from .ops import BinaryOps, ControlOps, MetaOps, Node
from .tensor import DType, Tensor, TensorSpec, f32, float32, i32, int32
from .trace import get_builder


# --- Symbols ---
class Symbol:
    """A named symbolic value (loop bound, iteration variable, etc.)."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return self.name

    def __lt__(self, other: Any) -> Node:
        """Record ``self < other`` as a graph node rather than comparing."""
        from .ops import _to_node as ops_to_node

        lhs = Node(MetaOps.VAR, (), arg=self)
        rhs = ops_to_node(other)
        return Node(BinaryOps.LT, (lhs, rhs))

def vars(names: str) -> Tuple[Symbol, ...]:
    """Create one :class:`Symbol` per whitespace-separated name in *names*."""
    symbols = [Symbol(token) for token in names.split()]
    return tuple(symbols)

# --- Directives ---
class Directive:
    """A scheduling hint (parallel/vectorize/unroll) attachable to a loop."""

    def __init__(self, name: str, args: Tuple[Any, ...] = ()):
        self.name = name
        self.args = args

    def __repr__(self) -> str:
        return f"Directive({self.name})"

def parallel() -> Directive:
    """Mark the annotated loop for parallel execution."""
    return Directive("parallel")


def vectorize(width: int = 4) -> Directive:
    """Mark the annotated loop for vectorization with *width* lanes."""
    return Directive("vectorize", (width,))


def unroll(factor: int = 4) -> Directive:
    """Mark the annotated loop for unrolling by *factor*."""
    return Directive("unroll", (factor,))

# --- Range ---
# Counter used to mint unique loop-iterator names ("i0", "i1", ...).
# NOTE(review): module-global mutable state is not thread-safe -- concurrent
# kernel traces would race on it; consider moving it onto GraphBuilder.
_range_counter = 0

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The global _range_counter is not thread-safe. If two kernels are traced in parallel threads, they will share and modify this counter, leading to a race condition. This should be moved into the GraphBuilder class in caten/trace.py to make it thread-local. This is the first of several changes to address this.


class RangeContext:
    """Context manager that records a traced loop.

    Entering opens a fresh block on the graph builder and yields a new
    iteration :class:`Symbol`; exiting pops the block and appends a
    ``ControlOps.RANGE`` node whose ``arg`` is
    ``(iter_sym, bounds, body, directives)``.
    """

    def __init__(self, *args: Union[int, Symbol]):
        # NOTE(review): the global counter is not thread-safe; see the
        # comment at its definition.
        global _range_counter
        self.args = args
        self.iter_sym = Symbol(f"i{_range_counter}")
        self.directives: List[Directive] = []
        _range_counter += 1

    def __or__(self, other: Directive) -> 'RangeContext':
        """Attach a scheduling directive: ``range(n) | parallel()``."""
        self.directives.append(other)
        return self

    def __enter__(self) -> Symbol:
        get_builder().push_block()
        return self.iter_sym

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Always pop so the builder's block stack stays balanced, but only
        # record the loop when the body traced successfully: emitting a
        # half-built body after an exception would corrupt the graph.
        body_block = get_builder().pop_block()
        if exc_type is not None:
            return  # returning None lets the exception propagate
        # arg structure: (iter_sym, bounds, body, directives)
        node = Node(ControlOps.RANGE, (), arg=(self.iter_sym, self.args, body_block, self.directives), name=self.iter_sym.name)
        get_builder().push(node)

def range(*args: Union[int, Symbol]) -> RangeContext:
    """Trace-time replacement for the builtin ``range``: opens a loop context."""
    return RangeContext(*args)

# --- Control Flow ---
class WhenContext:
    """Context manager that records a traced conditional block.

    Entering opens a fresh block on the graph builder; exiting pops it and
    appends a ``ControlOps.IF`` node whose ``arg`` is
    ``(cond, then_block, else_block)``.
    """

    def __init__(self, cond: Any):
        self.cond = cond

    def __enter__(self) -> None:
        get_builder().push_block()

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Always pop so the builder's block stack stays balanced, but only
        # record the IF node when the body traced successfully.
        body_block = get_builder().pop_block()
        if exc_type is not None:
            return  # returning None lets the exception propagate
        # arg structure: (cond, then_block, else_block)
        # For now else_block is empty
        node = Node(ControlOps.IF, (), arg=(self.cond, body_block, []))
        get_builder().push(node)

def when(cond: Any) -> WhenContext:
    """Open a traced conditional block guarded by *cond*."""
    return WhenContext(cond)

# --- Kernel ---
class Kernel:
    """A compiled kernel together with the traced graph it was built from."""

    def __init__(self, compiled_kernel: Any, graph: List[Node]):
        self.compiled_kernel = compiled_kernel
        self.graph = graph

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the compiled artifact with the given runtime arguments."""
        return self.compiled_kernel(*args, **kwargs)

    def print_graph(self) -> None:
        """Pretty-print the traced execution graph to stdout."""
        print("--- Execution Graph ---")
        self._print_block(self.graph, 0)

    def _print_block(self, block: List[Node], indent: int) -> None:
        # Recursive renderer; nesting deepens the prefix by two levels.
        prefix = " " * indent
        for node in block:
            print(f"{prefix}{node}")
            if node.op == ControlOps.RANGE:
                # RANGE arg layout: (iter_sym, bounds, body, directives)
                _, _, body, directives = node.arg
                if directives:
                    print(f"{prefix} Directives: {directives}")
                print(f"{prefix} Body:")
                self._print_block(body, indent + 2)
            elif node.op == ControlOps.IF:
                # IF arg layout: (cond, then_block, else_block)
                _, then_block, else_block = node.arg
                print(f"{prefix} Then:")
                self._print_block(then_block, indent + 2)
                if else_block:
                    print(f"{prefix} Else:")
                    self._print_block(else_block, indent + 2)

def kernel(get_kernel: bool = False) -> Callable:
    """Decorator factory: trace the wrapped function into a graph and compile it.

    With ``get_kernel=True`` the wrapper returns the compiled :class:`Kernel`
    object; otherwise it compiles and immediately executes it with the call's
    arguments. Tensor arguments whose node is a ``MetaOps.PLACEHOLDER`` are
    registered as graph inputs; when the function is called with no arguments,
    placeholders are synthesised from ``TensorSpec`` parameter annotations.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # 1. Reset per-trace state.
            # NOTE(review): the global counter is not thread-safe; see the
            # comment at its definition.
            global _range_counter
            _range_counter = 0
            builder = get_builder()
            builder.reset()

            # 2. Register inputs / create placeholders.
            sig = inspect.signature(func)

            if args or kwargs:
                # Register every PLACEHOLDER tensor whether it was passed
                # positionally or by keyword (previously kwargs were
                # silently dropped from tracing and from the final call).
                for arg in list(args) + list(kwargs.values()):
                    if isinstance(arg, Tensor) and arg.node.op == MetaOps.PLACEHOLDER:
                        if arg.node not in builder.inputs:
                            builder.register_input(arg.node)
                trace_args: Any = args
                trace_kwargs = kwargs
            else:
                # No concrete arguments: build placeholder tensors from the
                # TensorSpec annotations on the signature.
                trace_args = []
                for name, param in sig.parameters.items():
                    if isinstance(param.annotation, TensorSpec):
                        node = Node(MetaOps.PLACEHOLDER, (), arg=param.annotation, name=name)
                        builder.register_input(node)
                        trace_args.append(Tensor(node))
                trace_kwargs = {}

            # 3. Execute the function once to record the graph.
            _ = func(*trace_args, **trace_kwargs)

            # 4. Finalize graph.
            full_graph = builder.root_block

            # 5. Compile with the runtime selected via the RUNTIME env var.
            runtime_name = os.environ.get("RUNTIME", "CLANG")
            if runtime_name == "CLANG":
                from .runtimes.clang import ClangRuntime
                runtime = ClangRuntime()
            else:
                raise NotImplementedError(f"Runtime {runtime_name} not supported")

            compiled = runtime.compile(full_graph, builder.inputs)

            k_obj = Kernel(compiled, full_graph)

            if get_kernel:
                return k_obj
            # Forward keyword arguments too (previously dropped).
            return k_obj(*args, **kwargs)

        return wrapper
    return decorator

# Public API of this module; kept alphabetically sorted for easy scanning
# (matches the package-level __all__ convention).
__all__ = [
    "DType",
    "Tensor",
    "TensorSpec",
    "f32",
    "float32",
    "i32",
    "int32",
    "kernel",
    "parallel",
    "range",
    "unroll",
    "vars",
    "vectorize",
    "when",
]
Loading
Loading