6 changes: 6 additions & 0 deletions src/finchlite/__init__.py
@@ -108,6 +108,9 @@
     DenseLevelFType,
     ElementLevelFType,
     FiberTensorFType,
+    dense,
+    element,
+    fiber_tensor,
 )

 __all__ = [
@@ -158,11 +161,14 @@
     "cos",
     "cosh",
     "defer",
+    "dense",
     "dimension",
+    "element",
     "element_type",
     "elementwise",
     "equal",
     "expand_dims",
+    "fiber_tensor",
     "fill_value",
     "fisinstance",
     "flatten",
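Note on the new exports: `dense`, `element`, and `fiber_tensor` become reachable from the package root alongside their FType classes. A minimal usage sketch, assuming the constructors compose level formats the way the `DenseLevelFType`/`ElementLevelFType`/`FiberTensorFType` names suggest — the argument shapes here are assumptions, not confirmed by this diff:

```python
import finchlite

# Hypothetical composition: a fiber tensor over one dense level of
# element storage with fill value 0.0. Argument shapes are assumed.
fmt = finchlite.fiber_tensor(finchlite.dense(finchlite.element(0.0)))
```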
5 changes: 5 additions & 0 deletions src/finchlite/algebra/tensor.py
@@ -34,6 +34,11 @@ def shape_type(self) -> tuple[type, ...]:
         e.g. dtypes, formats, or types, and so that we can easily index it."""
         ...

+    @abstractmethod
+    def __init__(self, *args):
+        """TensorFType instance initializer."""
+        ...
+

 class Tensor(FTyped, ABC):
     """
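Why an abstract `__init__`: the compiler changes below rebuild ftypes with `type(val)(...)`, which only works if every concrete `TensorFType` accepts the same positional construction. A self-contained sketch of the contract, with a simplified illustrative subclass; note that `@abstractmethod` forces subclasses to define `__init__` but Python does not check the signature itself:

```python
from abc import ABC, abstractmethod

class TensorFType(ABC):
    @abstractmethod
    def __init__(self, *args):
        """TensorFType instance initializer."""
        ...

class NDArrayFType(TensorFType):  # illustrative subclass, not from the diff
    def __init__(self, buf_t, ndim, strides_t):
        self.buf_t, self.ndim, self.strides_t = buf_t, ndim, strides_t

old = NDArrayFType("buf", 2, ("i", "j"))
new = type(old)(old.buf_t, old.ndim, old.strides_t)  # same class, no hard-coding
assert type(new) is NDArrayFType
```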
22 changes: 10 additions & 12 deletions src/finchlite/autoschedule/compiler.py
@@ -14,7 +14,7 @@
     return_type,
 )
 from ..codegen import NumpyBufferFType
-from ..compile import BufferizedNDArrayFType, ExtentFType, dimension
+from ..compile import ExtentFType, dimension
 from ..finch_assembly import TupleFType
 from ..finch_logic import (
     Aggregate,
@@ -204,7 +204,7 @@ def __call__(
         return ntn.Assign(
             ntn.Variable(
                 name,
-                BufferizedNDArrayFType(
+                type(val)(

[Member Author, review comment on the line above]: That's something to address as a separate larger task

                     NumpyBufferFType(val.dtype),
                     val.ndim,
                     TupleFType.from_tuple(val.shape_type),
@@ -484,12 +484,12 @@ def find_suitable_rep(root, table_vars) -> TensorFType:
                 )
             )

-            return BufferizedNDArrayFType(
-                buf_t=NumpyBufferFType(dtype),
-                ndim=np.intp(len(result_fields)),
-                strides_t=TupleFType.from_tuple(
-                    tuple(field_type_map[f] for f in result_fields)
-                ),
+            # TODO: infer result rep from args
+            result_rep = type(args_suitable_reps_fields[0][0])
+            return result_rep(
+                NumpyBufferFType(dtype),
+                np.intp(len(result_fields)),
+                TupleFType.from_tuple(tuple(field_type_map[f] for f in result_fields)),
             )
         case Aggregate(Literal(op), init, arg, idxs):
             init_suitable_rep = find_suitable_rep(init, table_vars)
@@ -504,10 +504,8 @@
                 for f, st in zip(arg.fields, arg_suitable_rep.shape_type, strict=True)
                 if f not in idxs
             )
-            return BufferizedNDArrayFType(
-                buf_t=buf_t,
-                ndim=np.intp(len(strides_t)),
-                strides_t=TupleFType.from_tuple(strides_t),
+            return type(arg_suitable_rep)(
+                buf_t, np.intp(len(strides_t)), TupleFType.from_tuple(strides_t)
             )
         case LogicTree() as tree:
             for child in tree.children:
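The pattern in all three hunks is the same: instead of hard-coding `BufferizedNDArrayFType`, the compiler rebuilds whatever ftype class the value already has, so alternative tensor representations (such as the fiber tensors exported above) flow through unchanged. A minimal, self-contained sketch of the `type(x)(...)` idiom:

```python
class Rep:
    def __init__(self, n: int):
        self.n = n

class OtherRep(Rep):
    pass

def rebuild(rep: Rep) -> Rep:
    # Same-class reconstruction: OtherRep in, OtherRep out.
    return type(rep)(rep.n + 1)

assert type(rebuild(OtherRep(1))) is OtherRep
assert type(rebuild(Rep(1))) is Rep
```

Note the `# TODO` in the second hunk: taking the class of the first argument's rep is a stopgap, and inferring the result rep from all arguments is deferred.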
3 changes: 2 additions & 1 deletion src/finchlite/autoschedule/optimize.py
@@ -322,7 +322,7 @@ def propagate_into_reformats(root: LogicNode) -> LogicNode:
     class Entry:
         node: Query
         node_pos: int
-        matched: Query[LogicNode, Reformat] | None = None
+        matched: Query | None = None
         matched_pos: int | None = None

     def rule_0(ex: LogicNode) -> LogicNode | None:
@@ -347,6 +347,7 @@ def rule_0(ex: LogicNode) -> LogicNode | None:
                 if q.node.lhs not in PostOrderDFS(
                     Plan(tuple(new_bodies[q.node_pos + 1 :]))
                 ) and isinstance(q.node.rhs, MapJoin | Aggregate | Reorder):
+                    assert isinstance(q.matched.rhs, Reformat)
                     new_bodies[q.node_pos] = Query(
                         q.matched.lhs, Reformat(q.matched.rhs.tns, q.node.rhs)
                     )
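The assert replaces the type information that `Query[LogicNode, Reformat]` used to carry (the generic parameterization is removed in finch_logic/nodes.py below): a checker now only knows `q.matched.rhs` is a `LogicNode`, and the runtime assert narrows it back before `.tns` is accessed. A self-contained sketch of the pattern with stand-in classes:

```python
class LogicNode: ...

class Reformat(LogicNode):
    def __init__(self, tns: object):
        self.tns = tns

def take_tns(rhs: LogicNode) -> object:
    assert isinstance(rhs, Reformat)  # narrows LogicNode -> Reformat for type checkers
    return rhs.tns                    # attribute access is now well-typed

assert take_tns(Reformat("t")) == "t"
```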
4 changes: 4 additions & 0 deletions src/finchlite/codegen/numpy_buffer.py
@@ -76,6 +76,10 @@ def __str__(self):
         arr_str = str(self.arr).replace("\n", "")
         return f"np_buf({arr_str})"

+    def __repr__(self):
+        arr_repr = repr(self.arr).replace("\n", "")
+        return f"NumpyBuffer({arr_repr})"
+

 class NumpyBufferFType(CBufferFType, NumbaBufferFType, CStackFType):
     """
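This pairs a debug-oriented `__repr__` with the existing human-oriented `__str__`. A usage sketch; the import path and constructor shape are assumptions based on the file location, not confirmed by this diff:

```python
import numpy as np
from finchlite.codegen.numpy_buffer import NumpyBuffer  # assumed path

buf = NumpyBuffer(np.array([1, 2, 3]))  # assumed constructor
print(str(buf))   # np_buf([1 2 3])                   -- via __str__
print(repr(buf))  # NumpyBuffer(array([1, 2, 3]))     -- via __repr__
```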
2 changes: 1 addition & 1 deletion src/finchlite/compile/bufferized_ndarray.py
@@ -377,7 +377,7 @@ def asm_repack(self, ctx, lhs, obj):
         """
         Repack the buffer from C context.
         """
-        (self.tns.asm_repack(ctx, lhs.tns, obj.tns),)
+        self.tns.asm_repack(ctx, lhs.tns, obj.tns)
         ctx.exec(
             asm.Block(
                 asm.SetAttr(lhs, "tns", obj.tns),
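The old line still executed the repack call, but wrapped its result in a one-element tuple that was immediately discarded, likely a leftover trailing comma. Behavior is unchanged; the fix removes the dead wrapper:

```python
def f() -> None:
    print("called")

(f(),)  # f() still runs, but a throwaway 1-tuple (None,) is built and discarded
f()     # equivalent effect, nothing dead left behind
```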
9 changes: 6 additions & 3 deletions src/finchlite/compile/lower.py
@@ -15,14 +15,14 @@

 class FinchTensorFType(TensorFType, ABC):
     @abstractmethod
-    def lower_unwrap(tns):
+    def lower_unwrap(self, ctx, obj):
         """
         Unwrap a tensor view to get the underlying tensor.
         This is used to get the original tensor from a tensor view.
         """

     @abstractmethod
-    def lower_increment(tns, val):
+    def lower_increment(self, ctx, obj, val):
         """
         Increment a tensor view with an operation and value.
         This updates the tensor at the specified index with the operation and value.
@@ -47,7 +47,10 @@ def lower_thaw(self, ctx, tns, op):
         """

     @abstractmethod
-    def unfurl(self, ctx, tns, ext, mode, proto): ...
+    def unfurl(self, ctx, tns, ext, mode, proto):
+        """
+        Unfurl a tensor.
+        """


 @dataclass(eq=True, frozen=True)
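The old stubs (`def lower_unwrap(tns):`) lacked `self` and the lowering context, so they could not match any concrete override. With real signatures, the abstract base documents the one call shape all implementations share. A minimal sketch:

```python
from abc import ABC, abstractmethod

class FinchTensorFType(ABC):
    @abstractmethod
    def lower_unwrap(self, ctx, obj):
        """Unwrap a tensor view to get the underlying tensor."""

class DenseFType(FinchTensorFType):    # illustrative subclass, not from the diff
    def lower_unwrap(self, ctx, obj):  # matches the abstract call shape
        return obj
```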
2 changes: 1 addition & 1 deletion src/finchlite/finch_assembly/nodes.py
@@ -634,7 +634,7 @@ def __call__(self, prgm: AssemblyNode):
                 raise NotImplementedError(f"Unrecognized lhs type: {lhs}")
             return None
         case GetAttr(obj, attr):
-            return f"getattr({obj}, {attr})"
+            return f"{obj}.{attr}"
         case SetAttr(obj, attr, val):
             return f"setattr({obj}, {attr})"
         case Call(Literal(_) as lit, args):
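The old output was misleading: `attr` is interpolated unquoted, so `getattr(tns, shape)` is not even valid Python, while the new form prints natural attribute access. An illustration of the two f-strings:

```python
obj, attr = "tns", "shape"
print(f"getattr({obj}, {attr})")  # getattr(tns, shape)  -- old, not valid Python
print(f"{obj}.{attr}")            # tns.shape            -- new
```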
13 changes: 4 additions & 9 deletions src/finchlite/finch_logic/nodes.py
@@ -2,7 +2,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass
-from typing import Any, Generic, Self, TypeVar
+from typing import Any, Self

 from ..symbolic import Context, Term, TermTree, ftype, literal_repr
 from ..util import qual_str
@@ -40,11 +40,6 @@ def __str__(self):
         return res if res is not None else ctx.emit()


-# experiment with type variables
-LNVar1 = TypeVar("LNVar1", bound=LogicNode)
-LNVar2 = TypeVar("LNVar2", bound=LogicNode)
-
-
 @dataclass(eq=True, frozen=True)
 class LogicTree(LogicNode, TermTree, ABC):
     @property
@@ -332,7 +327,7 @@ def fields(self) -> list[Field]:


 @dataclass(eq=True, frozen=True)
-class Query(LogicTree, Generic[LNVar1, LNVar2]):
+class Query(LogicTree):
     """
     Represents a logical AST statement that evaluates `rhs`, binding the result to
     `lhs`.
@@ -342,8 +337,8 @@ class Query(LogicTree):
         rhs: The right-hand side to evaluate.
     """

-    lhs: LNVar1
-    rhs: LNVar2
+    lhs: LogicNode
+    rhs: LogicNode

     @property
     def children(self):
14 changes: 7 additions & 7 deletions src/finchlite/interface/lazy.py
@@ -332,24 +332,24 @@ def __ne__(self, other):
 register_property(LazyTensor, "asarray", "__attr__", lambda x: x)


-def asarray(arg: Any, format="bufferized") -> Any:
+def asarray(arg: Any, format=None) -> Any:
     """
     Convert given argument and return wrapper type instance.
     If input argument is already array type, return unchanged.

     Args:
         arg: The object to be converted.
         format: The format for the result array.

     Returns:
         The array type result of the given object.
     """
-    if format != "bufferized":
-        raise Exception(f"Only bufferized format is now supported, got: {format}")
+    if format is None:
+        if hasattr(arg, "asarray"):
+            return arg.asarray()
+        return query_property(arg, "asarray", "__attr__")

-    if hasattr(arg, "asarray"):
-        return arg.asarray()
-
-    return query_property(arg, "asarray", "__attr__")
+    return format(arg)


 def defer(arr) -> LazyTensor:
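The parameter changes meaning: `None` (the new default) keeps the attribute/property dispatch, while a non-None `format` is now invoked as a callable on the argument. Which callables are accepted is an assumption beyond the `format(arg)` call shown; a hedged usage sketch:

```python
import numpy as np
from finchlite.interface.lazy import asarray  # assumed import path

a = asarray(np.ones((2, 2)))  # format=None: dispatch via .asarray()/registered property

# With an explicit format, any callable that builds an array should work, e.g.
# a hypothetical tensor ftype acting as a constructor:
# b = asarray(np.ones((2, 2)), format=some_tensor_ftype)
```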
4 changes: 2 additions & 2 deletions src/finchlite/tensor/__init__.py
@@ -1,4 +1,4 @@
-from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, tensor
+from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, fiber_tensor
 from .level import (
     DenseLevel,
     DenseLevelFType,
@@ -19,5 +19,5 @@
     "LevelFType",
     "dense",
     "element",
-    "tensor",
+    "fiber_tensor",
 ]