6 changes: 6 additions & 0 deletions src/finchlite/__init__.py

@@ -135,6 +135,9 @@
     DenseLevelFType,
     ElementLevelFType,
     FiberTensorFType,
+    dense,
+    element,
+    fiber_tensor,
 )

 __all__ = [
@@ -188,17 +191,20 @@
     "cos",
     "cosh",
     "defer",
+    "dense",
     "dimension",
     "divide",
     "einop",
     "einsum",
+    "element",
     "element_type",
     "elementwise",
     "equal",
     "exp",
     "expand_dims",
     "expm1",
     "extent",
+    "fiber_tensor",
     "fill_value",
     "fisinstance",
     "flatten",
21 changes: 21 additions & 0 deletions src/finchlite/algebra/tensor.py

@@ -34,6 +34,27 @@ def shape_type(self) -> tuple[type, ...]:
         e.g. dtypes, formats, or types, and so that we can easily index it."""
         ...

+    @abstractmethod
+    def __init__(self, *args):
+        """TensorFType instance initializer."""
+        ...
+
+    # TODO: Remove and properly infer result rep
+    def add_levels(self, idxs: list[int]):
+        raise Exception("TODO: to remove")
+
+    # TODO: Remove and properly infer result rep
+    def remove_levels(self, idxs: list[int]):
+        raise Exception("TODO: to remove")
+
+    # TODO: Remove and properly infer result rep
+    def to_kwargs(self) -> dict[str, Any]:
+        raise Exception("TODO: to remove")
+
+    # TODO: Remove and properly infer result rep
+    def from_kwargs(self, **kwargs):
+        raise Exception("TODO: to remove")
+
+
 class Tensor(FTyped, ABC):
     """
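To make the temporary interface concrete, here is a minimal sketch of a hypothetical ftype implementing these four methods; the class name and fields are illustrative only, not part of the PR:

    from dataclasses import dataclass, replace
    from typing import Any

    @dataclass(frozen=True)
    class ToyDenseFType:  # hypothetical stand-in for a concrete TensorFType
        element_type: type
        dimension_type: tuple[type, ...]

        def to_kwargs(self) -> dict[str, Any]:
            # Expose the format as a plain "dict of attributes".
            return {"element_type": self.element_type,
                    "dimension_type": self.dimension_type}

        def from_kwargs(self, **kwargs):
            # Rebuild a format of the same kind from (possibly overridden) attributes.
            return type(self)(**kwargs)

        def add_levels(self, idxs: list[int]):
            # Insert placeholder levels at the given positions.
            dims = list(self.dimension_type)
            for i in sorted(idxs):
                dims.insert(i, int)
            return replace(self, dimension_type=tuple(dims))

        def remove_levels(self, idxs: list[int]):
            # Drop the levels at the given positions.
            dims = [d for i, d in enumerate(self.dimension_type) if i not in idxs]
            return replace(self, dimension_type=tuple(dims))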
55 changes: 33 additions & 22 deletions src/finchlite/autoschedule/compiler.py

@@ -14,8 +14,7 @@
     return_type,
 )
 from ..codegen import NumpyBufferFType
-from ..compile import BufferizedNDArrayFType, ExtentFType, dimension
-from ..finch_assembly import TupleFType
+from ..compile import ExtentFType, dimension
 from ..finch_logic import (
     Aggregate,
     Alias,
@@ -214,11 +213,7 @@ def __call__(
         return ntn.Assign(
             ntn.Variable(
                 name,
-                BufferizedNDArrayFType(
-                    NumpyBufferFType(val.dtype),
-                    val.ndim,
-                    TupleFType.from_tuple(val.shape_type),
-                ),
+                val.from_kwargs(val.to_kwargs()),
             ),
             compile_logic_constant(tns),
         )

Member Author: Proper handling of format inference is a separate, larger task (mapping traits.jl). Here I only convert the format to a "dict of attributes" so that some attributes can be overridden (for example, the dtype promoted from two inputs).
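The round trip above (to_kwargs followed by from_kwargs) is what lets the compiler rebuild a format while overriding individual attributes. A minimal sketch of the pattern; "rep" stands for any TensorFType instance implementing the temporary interface, and the names here are assumed:

    def promote(rep, dtype):
        kwargs = rep.to_kwargs()           # format as a plain dict of attributes
        kwargs.update(element_type=dtype)  # override e.g. the promoted dtype
        return rep.from_kwargs(**kwargs)   # rebuild the format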
@@ -498,13 +493,19 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
                 )
             )

-            return BufferizedNDArrayFType(
-                buf_t=NumpyBufferFType(dtype),
+            # TODO: properly infer result rep from args
+            result_rep, fields = args_suitable_reps_fields[0]
+            levels_to_add = [
+                idx for idx, f in enumerate(result_fields) if f not in fields
+            ]
+            result_rep = result_rep.add_levels(levels_to_add)
+            kwargs = result_rep.to_kwargs()
+            kwargs.update(
+                element_type=NumpyBufferFType(dtype),
-                ndim=np.intp(len(result_fields)),
-                strides_t=TupleFType.from_tuple(
-                    tuple(field_type_map[f] for f in result_fields)
-                ),
+                dimension_type=tuple(field_type_map[f] for f in result_fields),
             )
+            return result_rep.from_kwargs(**kwargs)
         case Aggregate(Literal(op), init, arg, idxs):
             init_suitable_rep = find_suitable_rep(init, table_vars)
             arg_suitable_rep = find_suitable_rep(arg, table_vars)

Member Author: As in the comment above, this is a temporary selection of the format (including removing and adding levels); it will be moved to a separate module.
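The add_levels bookkeeping selects the positions of result fields that the chosen argument's rep does not cover. A worked toy example of the list comprehension, with made-up values:

    # Toy values, not finchlite data:
    result_fields = ("i", "j", "k")
    fields = ("i", "k")  # fields covered by the chosen argument rep
    levels_to_add = [idx for idx, f in enumerate(result_fields) if f not in fields]
    assert levels_to_add == [1]  # "j" must be broadcast in at position 1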
@@ -513,16 +514,24 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
                 op, init_suitable_rep.element_type, arg_suitable_rep.element_type
             )
         )
-            strides_t = tuple(
-                st
-                for f, st in zip(arg.fields, arg_suitable_rep.shape_type, strict=True)
-                if f not in idxs
-            )
-            return BufferizedNDArrayFType(
-                buf_t=buf_t,
+            # TODO: properly infer result rep from args
+            levels_to_remove = []
+            strides_t = []
+            for idx, (f, st) in enumerate(
+                zip(arg.fields, arg_suitable_rep.shape_type, strict=True)
+            ):
+                if f not in idxs:
+                    strides_t.append(st)
+                else:
+                    levels_to_remove.append(idx)
+            arg_suitable_rep = arg_suitable_rep.remove_levels(levels_to_remove)
+            kwargs = arg_suitable_rep.to_kwargs()
+            kwargs.update(
+                buffer_type=buf_t,
-                ndim=np.intp(len(strides_t)),
-                strides_t=TupleFType.from_tuple(strides_t),
+                dimension_type=tuple(strides_t),
             )
+            return arg_suitable_rep.from_kwargs(**kwargs)
         case LogicTree() as tree:
             for child in tree.children:
                 suitable_rep = find_suitable_rep(child, table_vars)
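The loop partitions the argument's dimensions into those that survive the reduction (kept in strides_t) and those reduced out (recorded in levels_to_remove). A self-contained toy run of the same loop:

    # Toy values, not finchlite data:
    fields = ("i", "j", "k")          # arg.fields
    shape_type = (int, int, int)      # arg_suitable_rep.shape_type
    idxs = ("j",)                     # fields reduced over by the Aggregate
    levels_to_remove, strides_t = [], []
    for idx, (f, st) in enumerate(zip(fields, shape_type, strict=True)):
        if f not in idxs:
            strides_t.append(st)          # dimension survives the reduction
        else:
            levels_to_remove.append(idx)  # level is reduced out
    assert levels_to_remove == [1]
    assert len(strides_t) == 2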
@@ -555,11 +564,13 @@ class LogicCompiler:
     def __init__(self):
         self.ll = LogicLowerer()

-    def __call__(self, prgm: LogicNode) -> tuple[ntn.NotationNode, dict[Alias, Table]]:
+    def __call__(
+        self, prgm: LogicNode
+    ) -> tuple[ntn.NotationNode, dict[Alias, ntn.Variable], dict[Alias, Table]]:
         prgm, table_vars, slot_vars, dim_size_vars, tables, field_relabels = (
             record_tables(prgm)
         )
         lowered_prgm = self.ll(
             prgm, table_vars, slot_vars, dim_size_vars, field_relabels
         )
-        return merge_blocks(lowered_prgm), tables
+        return merge_blocks(lowered_prgm), table_vars, tables
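Since __call__ now also returns the alias-to-variable mapping, call sites unpack three values. A usage sketch with assumed variable names:

    compiler = LogicCompiler()
    lowered_prgm, table_vars, tables = compiler(prgm)  # was: lowered_prgm, tables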
7 changes: 4 additions & 3 deletions src/finchlite/autoschedule/einsum.py

@@ -29,9 +29,9 @@ def compile_plan(
             return None
         case lgc.Query(
             lgc.Alias(name),
-            lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _),
+            lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _) as agg,
         ):
-            einidxs = tuple(ein.Index(field.name) for field in node.rhs.fields)
+            einidxs = tuple(ein.Index(field.name) for field in agg.fields)
             my_bodies = []
             if init != init_value(operation, type(init)):
                 my_bodies.append(
@@ -52,11 +52,12 @@
             )
             return ein.Plan(tuple(my_bodies))
         case lgc.Query(lgc.Alias(name), rhs):
+            assert isinstance(rhs, lgc.LogicExpression)
             einarg = self.compile_operand(rhs)
             return ein.Einsum(
                 op=ein.Literal(overwrite),
                 tns=ein.Alias(name),
-                idxs=tuple(ein.Index(field.name) for field in node.rhs.fields),
+                idxs=tuple(ein.Index(field.name) for field in rhs.fields),
                 arg=einarg,
             )
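The key change is binding the matched Aggregate with an "as" capture so its fields are read off the matched pattern itself rather than reached through node.rhs. A self-contained toy example of the capture (not finchlite code):

    from dataclasses import dataclass

    @dataclass
    class Aggregate:  # toy stand-in, not lgc.Aggregate
        op: str
        fields: tuple[str, ...]

    node = Aggregate("add", ("i", "j"))
    match node:
        case Aggregate(op="add") as agg:
            # "as" binds the matched object itself, so agg.fields needs no
            # re-derivation from the enclosing node.
            print(agg.fields)  # ("i", "j")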
8 changes: 4 additions & 4 deletions src/finchlite/codegen/numba_backend.py

@@ -716,11 +716,11 @@ def struct_numba_setattr(fmt: AssemblyStructFType, ctx, obj, attr, val):


 def struct_construct_from_numba(fmt: AssemblyStructFType, numba_struct):
-    args = [
-        construct_from_numba(field_type, getattr(numba_struct, name))
+    kwargs = {
+        name: construct_from_numba(field_type, getattr(numba_struct, name))
         for (name, field_type) in fmt.struct_fields
-    ]
-    return fmt(*args)
+    }
+    return fmt(**kwargs)


 register_property(

Member Author: Formats now have "multiple constructors": we can construct them either from Numba or through the user-facing API.
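Constructing the struct by keyword rather than position makes the call robust to the order of fmt.struct_fields, which is what allows the same format to expose multiple constructors. A toy illustration with assumed names (not finchlite code):

    class Pair:  # hypothetical format with a keyword-friendly constructor
        def __init__(self, x: int = 0, y: float = 0.0):
            self.x, self.y = x, y

    struct_fields = [("y", float), ("x", int)]  # storage order differs from signature
    values = {"y": 2.5, "x": 1}
    kwargs = {name: values[name] for name, _ in struct_fields}
    p = Pair(**kwargs)  # correct regardless of field order
    # Pair(*[values[n] for n, _ in struct_fields]) would swap x and y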
4 changes: 4 additions & 0 deletions src/finchlite/codegen/numpy_buffer.py

@@ -76,6 +76,10 @@ def __str__(self):
         arr_str = str(self.arr).replace("\n", "")
         return f"np_buf({arr_str})"

+    def __repr__(self):
+        arr_repr = repr(self.arr).replace("\n", "")
+        return f"NumpyBuffer({arr_repr})"
+

 class NumpyBufferFType(CBufferFType, NumbaBufferFType, CStackFType):
     """
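Expected behavior of the paired __str__/__repr__, assuming NumpyBuffer wraps a NumPy array as self.arr (the constructor usage below is an assumption, not shown in the diff):

    import numpy as np

    buf = NumpyBuffer(np.array([1, 2, 3]))  # assumed constructor
    print(str(buf))    # np_buf([1 2 3])
    print(repr(buf))   # NumpyBuffer(array([1, 2, 3]))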