Skip to content

Commit f320079

Browse files
authored
Feat: Tensor, DType, PatternMatcher, etc (#21)
* wip * concepts * wip * wip * wip * facet ideas * allocation graph * allocation graph * allocation graph * allocation graph * concrete semantics * dataclass * dataclass * dataclass * dataclass * tuple * concrete semantics * concrete semantics * concrete semantics * WIP * some clean ups * just an idea * feat: reshape * feat: reshape * feat: expand * feat: expand * Cache, Type Inference, Verification * Cache, Type Inference, Verification * Cache, Type Inference, Verification * Cache, Type Inference, Verification * Feat: Simplifier * Feat: Simplifier * Feat: Simplifier * Reshape symbolic * docs * docs * smax * write specs on def * write specs on def * Fix scope issue * Fix scope issue * fix ruff * fix ruff * fix ruff * helpers * helpers * helpers * helpers * helpers * helpers * helpers * helpers * helpers
1 parent ab5e6f1 commit f320079

File tree

13 files changed

+817
-36
lines changed

13 files changed

+817
-36
lines changed

AGENTS.md

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,4 @@
3737

3838
## Polyhedral DSL Guidelines
3939
- Prefer using Mixin operator overloads (e.g., `A | B` instead of `A.union(B)`) for cleaner code in user scripts and DSL implementations.
40-
41-
## 作業計画と進捗 (2025-11-16)
42-
直近のギャップ集計: `docs/ISL_missing_apis.md`(2025-11-16 再生成、欠落API 2047件)。map 残 2 件(tuple_name系シンボル未提供のみ、libisl非存在)。
43-
優先順とステータス(✅完了 / 🚧着手中 / ⏳未着手)
44-
- Identifier / Id: 🚧(基本APIは揃うが欠落検証 継続)
45-
- Space / LocalSpace: 🚧(dim/tuple系以外の抜け有り)
46-
- Constraint / Equality-Constraint / Inequality-Constraint: 🚧
47-
- BasicSet / Set: 🚧(missing計: set 105, basic_set 63)
48-
- UnionSet: 🚧(missing計: union_set 52)
49-
- BasicMap / Map: 🚧(missing計: basic_map 85, map 190)
50-
- UnionMap: 🚧(missing計: union_map 112)
51-
- Aff / PwAff / MultiAff / PwMultiAff: 🚧(missing計: aff 73, pw_aff 96, multi_aff 90, pw_multi_aff 89)
52-
- MultiVal: 🚧(missing計: multi_val 37, val 66)
53-
- MultiUnionPwAff / UnionPwAff / UnionPwMultiAff / MultiUnionPwAff: 🚧(missing計: multi_union_pw_aff 75 ほか)
54-
- ScheduleConstraint / Schedule / ScheduleNode: ✅(schedule_node 0)
55-
- UnionAccessInfo / UnionFlow: ⏳
56-
- ASTExpr / ASTNode / ASTBuild: 🚧(Expr系クラス不足・missing計: ast_expr 0, ast_node 0)
57-
- Mat: ✅(要素参照系API実装済・missing計: mat 0)
58-
- その他: misc 71, options 29 など多数。
59-
60-
次に着手する対象:
61-
1) ScheduleNode / ASTExpr / Mat のクラス追加・アクセサ補完
62-
2) UnionAccessInfo / UnionFlow ラッパ実装
63-
3) map / set 系を皮切りに `docs/ISL_missing_apis.md` に基づく欠落API埋め
40+
- Do not write shit code, be respectful to existing codes.

caten/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .dtype import * # noqa: F403, I001
2+
from .simplifier import * # noqa: F403, I001
3+
from .tensor import * # noqa: F403, I001
4+
from .runtime import cpu # noqa: I001, F401

caten/dtype.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Any
5+
6+
7+
class DTypeMetaClass(type):
8+
dcache: dict[tuple, DType] = {}
9+
def __call__(cls, *args: Any, **kwargs: Any) -> DType:
10+
if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
11+
DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
12+
return ret
13+
14+
# TODO: Vector/Packed DType
15+
@dataclass(frozen=True, eq=False)
16+
class DType:
17+
name: str
18+
@staticmethod
19+
def new(name:str) -> DType: return DType(name)
20+
21+
## definitions
22+
float64 = DType.new("float64")
23+
float32 = DType.new("float32")
24+
float16 = DType.new("float16")
25+
26+
int64 = DType.new("int64")
27+
int32 = DType.new("int32")
28+
int16 = DType.new("int16")
29+
int8 = DType.new("int8")
30+
uint64 = DType.new("uint64")
31+
uint32 = DType.new("uint32")
32+
uint16 = DType.new("uint16")
33+
uint8 = DType.new("uint8")
34+
35+
## dtype aliases
36+
index = int64
37+
default_float = float32
38+
39+
floats = [float64, float32, float16]
40+
integers = [int64, int32, int16, int8, uint64, uint32, uint16, uint8]

caten/helpers.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import operator
5+
from typing import Any, Iterable, TypeVar
6+
7+
T = TypeVar("T")
8+
9+
def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
10+
11+
def argfix(*x: Any) -> tuple[Any, ...]:
12+
if x and x[0].__class__ in (tuple, list):
13+
if len(x) != 1: raise ValueError(f"bad arg {x}")
14+
return tuple(x[0])
15+
return x
16+
17+
def align_left(*shapes: tuple[Any, ...]) -> tuple[tuple[Any, ...], ...]:
18+
# unsqueeze left to make every shape same length
19+
max_dim = max(len(shape) for shape in shapes)
20+
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)

0 commit comments

Comments
 (0)