Skip to content

Commit 9ae1800

Browse files
Align Numba codegen to the new assembly. E2E Numba test. (#161)
* Align Numba codegen to new assembly. E2E Numba test.
* Delete `override` import
* Remove `shape` from `BufferizedNDArrayFType`
* Remove colon from labels
* Remove dangling shapes
* Cleanup
* Apply review comments
* Custom serializer for BufferizedNDArray
* Keep only `numba_type` for `BufferizedNDArray`
* Remove `construct_from_numba` from `BufferizedNDArray`
* Rename function to `provision_tensors`
* Move `numba_type` to `AssemblyStructFType`
* Moved last serialization part to `AssemblyStructFType`
* Use kind and itemsize instead of char for dtype str
* move the serialization stuff around a bit (just to keep the dependency order consistent)
* pre-commit

---------

Co-authored-by: Willow Ahrens <[email protected]>
1 parent 9b43331 commit 9ae1800

26 files changed

+808
-376
lines changed

src/finchlite/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
ExtentFType,
1010
NotationContext,
1111
dimension,
12-
extent,
1312
)
1413
from .galley import (
1514
DC,
@@ -20,6 +19,7 @@
2019
from .interface import (
2120
EagerTensor,
2221
LazyTensor,
22+
Mode,
2323
Scalar,
2424
abs,
2525
acos,
@@ -54,6 +54,7 @@
5454
floordiv,
5555
fuse,
5656
fused,
57+
get_default_scheduler,
5758
greater,
5859
greater_equal,
5960
less,
@@ -82,6 +83,7 @@
8283
pow,
8384
prod,
8485
reduce,
86+
set_default_scheduler,
8587
sin,
8688
sinh,
8789
split_dims,
@@ -121,6 +123,7 @@
121123
"FTyped",
122124
"FiberTensorFType",
123125
"LazyTensor",
126+
"Mode",
124127
"NotationContext",
125128
"NumpyBuffer",
126129
"NumpyBufferFType",
@@ -160,14 +163,14 @@
160163
"elementwise",
161164
"equal",
162165
"expand_dims",
163-
"extent",
164166
"fill_value",
165167
"fisinstance",
166168
"flatten",
167169
"floordiv",
168170
"ftype",
169171
"fuse",
170172
"fused",
173+
"get_default_scheduler",
171174
"greater",
172175
"greater_equal",
173176
"less",
@@ -196,6 +199,7 @@
196199
"pow",
197200
"prod",
198201
"reduce",
202+
"set_default_scheduler",
199203
"shape_type",
200204
"sin",
201205
"sinh",

src/finchlite/algebra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .tensor import (
2525
Tensor,
2626
TensorFType,
27+
TensorPlaceholder,
2728
element_type,
2829
fill_value,
2930
shape_type,
@@ -34,6 +35,7 @@
3435
"StableNumber",
3536
"Tensor",
3637
"TensorFType",
38+
"TensorPlaceholder",
3739
"conjugate",
3840
"conjugate",
3941
"element_type",

src/finchlite/algebra/tensor.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
class TensorFType(FType, ABC):
1111
@property
12-
def ndim(self) -> int:
12+
def ndim(self) -> np.intp:
1313
"""Number of dimensions of the tensor."""
14-
return len(self.shape_type)
14+
return np.intp(len(self.shape_type))
1515

1616
@property
1717
@abstractmethod
@@ -21,13 +21,13 @@ def fill_value(self) -> Any:
2121

2222
@property
2323
@abstractmethod
24-
def element_type(self):
24+
def element_type(self) -> Any:
2525
"""Data type of the tensor elements."""
2626
...
2727

2828
@property
2929
@abstractmethod
30-
def shape_type(self) -> tuple:
30+
def shape_type(self) -> tuple[type, ...]:
3131
"""Shape type of the tensor. The shape type is a tuple of the index
3232
types in the tensor. It's the type of each element in tns.shape. It
3333
should be an actual tuple, rather than a tuple type, so that it can hold
@@ -46,16 +46,10 @@ class Tensor(FTyped, ABC):
4646
"""
4747

4848
@property
49-
def ndim(self) -> int:
49+
def ndim(self) -> np.intp:
5050
"""Number of dimensions of the tensor."""
5151
return self.ftype.ndim
5252

53-
@property
54-
@abstractmethod
55-
def shape(self):
56-
"""Shape of the tensor as a tuple."""
57-
...
58-
5953
@property
6054
@abstractmethod
6155
def ftype(self) -> TensorFType:
@@ -143,7 +137,7 @@ class NDArrayFType(TensorFType):
143137
This includes the fill value, element type, and shape type.
144138
"""
145139

146-
def __init__(self, dtype: np.dtype, ndim: int):
140+
def __init__(self, dtype: np.dtype, ndim: np.intp):
147141
self._dtype = dtype
148142
self._ndim = ndim
149143

@@ -159,7 +153,7 @@ def __repr__(self) -> str:
159153
return f"NDArrayFType(dtype={repr(self._dtype)}, ndim={self._ndim})"
160154

161155
@property
162-
def ndim(self) -> int:
156+
def ndim(self) -> np.intp:
163157
return self._ndim
164158

165159
@property
@@ -176,5 +170,14 @@ def shape_type(self) -> tuple:
176170

177171

178172
register_property(
179-
np.ndarray, "ftype", "__attr__", lambda x: NDArrayFType(x.dtype, x.ndim)
173+
np.ndarray, "ftype", "__attr__", lambda x: NDArrayFType(x.dtype, np.intp(x.ndim))
180174
)
175+
176+
177+
class TensorPlaceholder:
178+
def __init__(self, dtype):
179+
self._dtype = dtype
180+
181+
@property
182+
def dtype(self):
183+
return self._dtype

src/finchlite/autoschedule/_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1-
from collections.abc import Iterable
1+
from collections.abc import Iterable, Sequence
22

33

44
def intersect(x1: Iterable, x2: Iterable) -> tuple:
55
return tuple(x for x in x1 if x in x2)
66

77

8-
def is_subsequence(x1: tuple, x2: tuple) -> bool:
9-
return x1 == tuple(x for x in x2 if x in x1)
8+
def extend_uniqe(x1: Iterable, x2: Iterable) -> tuple:
9+
return tuple(x1) + setdiff(x2, x1)
1010

1111

12-
def setdiff(x1: tuple, x2: tuple) -> tuple:
12+
def is_subsequence(x1: Iterable, x2: Iterable) -> bool:
13+
return tuple(x1) == tuple(x for x in x2 if x in x1)
14+
15+
16+
def setdiff(x1: Iterable, x2: Iterable) -> tuple:
1317
return tuple([x for x in x1 if x not in x2])
1418

1519

16-
def with_subsequence(x1: tuple, x2: tuple) -> tuple:
20+
def with_subsequence(x1: Sequence, x2: Iterable) -> tuple:
1721
res = list(x2)
1822
indices = [idx for idx, val in enumerate(x2) if val in x1]
1923
for idx, i in enumerate(indices):

0 commit comments

Comments (0)