-
Notifications
You must be signed in to change notification settings - Fork 1
Feat: Tensor, DType, PatternMatcher, etc #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
101b88b
9578c37
fec84c0
7856871
2d4c0ce
aa350ec
d800a8e
51de548
557dc10
2ab506b
673906c
e4ddd5e
3ac5d96
e66e626
e4bacfd
a3056eb
e259924
227b779
19db18b
74471c9
a5c0b82
7e60cc0
71fc714
704b6be
0b526ce
62de279
f679b50
bc0f4a1
7470551
d66007c
8d68192
5820b67
fa3ae24
d43c14b
6437e69
f35442a
fc442a6
3061ab2
d6e83da
10b0a41
b0a7749
c34c90a
51abfad
6215735
e72883f
50372b0
77c9157
4700274
a2ce359
dec32ae
8c5e351
8730f7c
1cb1d46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| from . import dtype, helpers, ir, tensor | ||
| from .tensor import ATenSpec, ATen, ATenMath, ATenMovements, ATenNN, ATenLinalg, ATenBase, get_backend, Tensor | ||
| from .runtime import cpu | ||
|
|
||
| __all__ = [ | ||
| "dtype", | ||
| "helpers", | ||
| "ir", | ||
| "tensor" | ||
| ] | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,25 @@ | ||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||
| from dataclasses import dataclass, fields | ||||||||||||||||||
|
||||||||||||||||||
|
|
||||||||||||||||||
| class DTypeMetaClass(type): | ||||||||||||||||||
| dcache: dict[tuple, DType] = {} | ||||||||||||||||||
| def __call__(cls, *args, **kwargs): | ||||||||||||||||||
| if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret | ||||||||||||||||||
| DTypeMetaClass.dcache[args] = ret = super().__call__(*args) | ||||||||||||||||||
| return ret | ||||||||||||||||||
|
|
||||||||||||||||||
| # TODO: Vector/Packed DType | ||||||||||||||||||
| @dataclass(frozen=True, eq=False) | ||||||||||||||||||
| class DType: | ||||||||||||||||||
| name: str | ||||||||||||||||||
| @staticmethod | ||||||||||||||||||
| def new(name:str): return DType(name) | ||||||||||||||||||
|
||||||||||||||||||
| @dataclass(frozen=True, eq=False) | |
| class DType: | |
| name: str | |
| @staticmethod | |
| def new(name:str): return DType(name) | |
| @dataclass(frozen=True) | |
| class DType(metaclass=DTypeMetaClass): | |
| name: str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Following the change to use the metaclass for DType, you should now instantiate dtypes by calling DType(...) directly, which will leverage the caching mechanism.
| float64 = DType.new("float64") | |
| float32 = DType.new("float32") | |
| int64 = DType.new("int64") | |
| int32 = DType.new("int32") | |
| float64 = DType("float64") | |
| float32 = DType("float32") | |
| int64 = DType("int64") | |
| int32 = DType("int32") |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,167 @@ | ||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| from abc import ABCMeta, abstractmethod | ||||||||||||||||||||||||
| from typing import List, Dict, Any | ||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||
| from .dtype import DType | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @dataclass(frozen=True) | ||||||||||||||||||||||||
| class ATenAxis(): | ||||||||||||||||||||||||
| shape: ATenOp | ||||||||||||||||||||||||
| stride: ATenOp | ||||||||||||||||||||||||
| offset: ATenOp | ||||||||||||||||||||||||
| incf: ATenOp | ||||||||||||||||||||||||
| def index(self, i: ATenOp): | ||||||||||||||||||||||||
| # TODO: Assert i.T.dtype is dtype.index | ||||||||||||||||||||||||
| return Mul(self.stride, Add(Mul(i, self.incf), self.offset)) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @dataclass(frozen=True) | ||||||||||||||||||||||||
| class ATenOpType(): | ||||||||||||||||||||||||
| shape: List[ATenAxis] | ||||||||||||||||||||||||
| dtype: DType | ||||||||||||||||||||||||
| offset: ATenOp | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @dataclass(frozen=True) | ||||||||||||||||||||||||
| class ATenOp(metaclass=ABCMeta): | ||||||||||||||||||||||||
| args: List[AtenOp] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| T: ATenOpType | ||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||||
| def from_astexpr(cls): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||||
| def infer_dtype(self): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| ## == Tensor Graph ============================================================ | ||||||||||||||||||||||||
| class UnaryOps(): | ||||||||||||||||||||||||
| def verify(self): verify_tensor_op(self, 1) | ||||||||||||||||||||||||
| class BinaryOps(): | ||||||||||||||||||||||||
| def verify(self): verify_tensor_op(self, 2) | ||||||||||||||||||||||||
| class TernaryOps(): | ||||||||||||||||||||||||
| def verify(self): verify_tensor_op(self, 3) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| ### UnaryOps | ||||||||||||||||||||||||
| class Neg(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| OUT = -X | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Recip(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Sin(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Exp2(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Log2(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Sqrt(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Cast(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Bitcast(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Not(ATenOp, UnaryOps): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Logical not if the X is a boolean | ||||||||||||||||||||||||
| otherwise lognot ~x | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
| ### BinaryOps | ||||||||||||||||||||||||
| class Add(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| OUT = Add(X, Y) | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||
| def from_ast_expr(cls): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Mul(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| OUT = Mul(X, Y) | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||
| def from_ast_expr(cls): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class IDiv(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class And(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Or(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class And(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Xor(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Max(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Mod(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Neq(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Lt(ATenOp, BinaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
| ### TernaryOps | ||||||||||||||||||||||||
| class Where(ATenOp, TernaryOps): | ||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| ### Allocation | ||||||||||||||||||||||||
| class Variable(ATenOp): | ||||||||||||||||||||||||
| symbol: str | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| class Variable(ATenOp): | |
| symbol: str | |
| @dataclass(frozen=True) | |
| class Variable(ATenOp): | |
| symbol: str |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reduce inherits from the ATenOp dataclass but is not a dataclass itself. This means op is only a class annotation, not an instance field. To fix this, you should make Reduce a dataclass.
| class Reduce(ATenOp): | |
| """ | |
| OUT = Reduce(A, B, op=BinaryOps) | |
| """ | |
| op: BinaryOps | |
| @dataclass(frozen=True) | |
| class Reduce(ATenOp): | |
| """ | |
| OUT = Reduce(A, B, op=BinaryOps) | |
| """ | |
| op: BinaryOps |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import caten as C | ||
|
|
||
| class CPUTensor(C.ATenBase): | ||
| def allocate(self): | ||
| pass | ||
|
|
||
| def free(self): | ||
| pass | ||
|
|
||
| def compile(self): | ||
| pass | ||
|
|
||
| C.ATenBase.register("CPU", CPUTensor) |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,89 @@ | ||||||||||||
| from __future__ import annotations | ||||||||||||
| from abc import ABCMeta, abstractmethod | ||||||||||||
| from typing import Any, Optional, Tuple, Union, ClassVar | ||||||||||||
| import os | ||||||||||||
|
||||||||||||
| from typing import Any, Optional, Tuple, Union, ClassVar | |
| import os | |
| from typing import Any, Optional, Tuple, Union, ClassVar, List, Callable | |
| from dataclasses import dataclass | |
| import os |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ATen is instantiated as if it were a dataclass in the apply method, but it's not defined as one. This will cause a TypeError. You should decorate it with @dataclass. Additionally, ATenOp is not defined in this scope and should be referenced as ir.ATenOp.
| class ATen: | |
| op: ATenOp # ATen is just a wrapper for ATenOp | |
| @dataclass(frozen=True) | |
| class ATen: | |
| op: ir.ATenOp # ATen is just a wrapper for ATenOp |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint *args: List is invalid syntax. To type variadic positional arguments, you can use *args: Any. It's also good practice to type **kwargs as **kwargs: Any.
| def apply(self, op: Callable, *args: List, **kwargs) -> ATen: return ATen(op=op(*args, **kwargs)) | |
| def apply(self, op: Callable, *args: Any, **kwargs: Any) -> ATen: return ATen(op=op(*args, **kwargs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For efficiency and clarity, you can store the result of get_backend() in a variable to avoid calling it twice.
| impl = DEVICE_TO_TENSOR.get(get_backend()) | |
| if impl is None: raise ValueError(f"Unknown BACKEND={get_backend()}") | |
| backend = get_backend() | |
| impl = DEVICE_TO_TENSOR.get(backend) | |
| if impl is None: raise ValueError(f"Unknown BACKEND={backend}") |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| import caten as C | ||
|
|
||
| def test_tensor(): | ||
| print(C.Tensor()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
__all__list is missing key public APIs likeTensor. To provide a better user experience, it's recommended to export the main classes and functions that users are expected to use directly. This makes them discoverable and allows for imports likefrom caten import Tensor.