Skip to content

Commit ca5db23

Browse files
committed
Remove BufferizedNDArray mentions from autoschedule.compiler
1 parent b45493b commit ca5db23

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/finchlite/algebra/tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def shape_type(self) -> tuple[type, ...]:
3434
e.g. dtypes, formats, or types, and so that we can easily index it."""
3535
...
3636

37+
@abstractmethod
38+
def __init__(self, *args):
39+
"""TensorFType instance initializer."""
40+
...
41+
3742

3843
class Tensor(FTyped, ABC):
3944
"""

src/finchlite/autoschedule/compiler.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
return_type,
1515
)
1616
from ..codegen import NumpyBufferFType
17-
from ..compile import BufferizedNDArrayFType, ExtentFType, dimension
17+
from ..compile import ExtentFType, dimension
1818
from ..finch_assembly import TupleFType
1919
from ..finch_logic import (
2020
Aggregate,
@@ -204,7 +204,7 @@ def __call__(
204204
return ntn.Assign(
205205
ntn.Variable(
206206
name,
207-
BufferizedNDArrayFType(
207+
type(val)(
208208
NumpyBufferFType(val.dtype),
209209
val.ndim,
210210
TupleFType.from_tuple(val.shape_type),
@@ -484,12 +484,12 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
484484
)
485485
)
486486

487-
return BufferizedNDArrayFType(
488-
buf_t=NumpyBufferFType(dtype),
489-
ndim=np.intp(len(result_fields)),
490-
strides_t=TupleFType.from_tuple(
491-
tuple(field_type_map[f] for f in result_fields)
492-
),
487+
# TODO: infer result rep from args
488+
result_rep = type(args_suitable_reps_fields[0][0])
489+
return result_rep(
490+
NumpyBufferFType(dtype),
491+
np.intp(len(result_fields)),
492+
TupleFType.from_tuple(tuple(field_type_map[f] for f in result_fields)),
493493
)
494494
case Aggregate(Literal(op), init, arg, idxs):
495495
init_suitable_rep = find_suitable_rep(init, table_vars)
@@ -504,10 +504,8 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
504504
for f, st in zip(arg.fields, arg_suitable_rep.shape_type, strict=True)
505505
if f not in idxs
506506
)
507-
return BufferizedNDArrayFType(
508-
buf_t=buf_t,
509-
ndim=np.intp(len(strides_t)),
510-
strides_t=TupleFType.from_tuple(strides_t),
507+
return type(arg_suitable_rep)(
508+
buf_t, np.intp(len(strides_t)), TupleFType.from_tuple(strides_t)
511509
)
512510
case LogicTree() as tree:
513511
for child in tree.children:

0 commit comments

Comments
 (0)