Changes from all commits
50 commits
193dbb1
* Added basic sparse tensor type.
TheRealMichaelWang Oct 28, 2025
101caf3
* Fixed ruff errors
TheRealMichaelWang Oct 28, 2025
ae38183
* Fixed mypy typing issues
TheRealMichaelWang Oct 28, 2025
bddbac9
* Fixed ruff whitespace errors
TheRealMichaelWang Oct 28, 2025
7318931
* Added in dimension check for safety
TheRealMichaelWang Oct 28, 2025
3d27adc
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Nov 4, 2025
335c75c
* Added GetAttribute Einsum IR node
TheRealMichaelWang Nov 5, 2025
6ab5e06
* Added support for printing GetAttribute Einsum IR node
TheRealMichaelWang Nov 5, 2025
e176ef4
* Added einsum interpreter support for evaluating GetAttribute IR node
TheRealMichaelWang Nov 5, 2025
2936eac
* Added support for indirection in ein.Access
TheRealMichaelWang Nov 5, 2025
2c061ab
* Added support for indirection with multi-dimension indices. I.e. A…
TheRealMichaelWang Nov 5, 2025
46fda9c
* Added comments explaining indirect access implementation
TheRealMichaelWang Nov 5, 2025
d32ae5b
* Fixed bugs in einsum interpreter that stemmed from newly added supp…
TheRealMichaelWang Nov 6, 2025
5f09be4
* Fixed some bugs
TheRealMichaelWang Nov 6, 2025
860e016
* Revert changes to ein.Access
TheRealMichaelWang Nov 7, 2025
bf43dbb
* Added separate match case handlers in einsum interpreter loop to ha…
TheRealMichaelWang Nov 7, 2025
2f14a25
* Added support for one index indirect access
TheRealMichaelWang Nov 7, 2025
9c57374
* Added support for multiple indirect indices in access in einsum in…
TheRealMichaelWang Nov 7, 2025
5ea1ddd
Enhanced einsum interpreter to evaluate tensor shapes and permute dim…
TheRealMichaelWang Nov 7, 2025
71268a8
* Added match case in einsum interpreter loop to handle einsums with in…
TheRealMichaelWang Nov 7, 2025
24ef52c
* Added support for getting shape attribute of a sparse tensor
TheRealMichaelWang Nov 7, 2025
85d81f7
Implemented direct einsum handling for indirect assignments without r…
TheRealMichaelWang Nov 7, 2025
710a2e3
Refactored indirect einsum handling in the interpreter to support red…
TheRealMichaelWang Nov 7, 2025
1dc9667
Removed implementation of indirect einsum with reduction in the inter…
TheRealMichaelWang Nov 11, 2025
d953757
Enhanced EinsumInterpreter to handle cases with fewer indices than di…
TheRealMichaelWang Nov 11, 2025
09c9879
Renamed test for indirect access to clarify focus on elementwise mult…
TheRealMichaelWang Nov 11, 2025
f65e4f9
Add tests for indirect access in einsum operations
TheRealMichaelWang Nov 11, 2025
7549b62
Refactor EinsumInterpreter to handle flat indexing for 1D arrays and …
TheRealMichaelWang Nov 11, 2025
29f5f03
Refactor EinsumInterpreter to improve handling of indirect indexing a…
TheRealMichaelWang Nov 11, 2025
6fe0528
Remove redundant calculation of parent indices in EinsumInterpreter t…
TheRealMichaelWang Nov 11, 2025
8d0a788
Refactor EinsumInterpreter to enhance indirect indexing logic by sepa…
TheRealMichaelWang Nov 12, 2025
cc09bf5
Refactor EinsumInterpreter to improve index handling by grouping ein.…
TheRealMichaelWang Nov 12, 2025
169cebf
Refactor EinsumInterpreter to enhance index classification and groupi…
TheRealMichaelWang Nov 12, 2025
b2f81e6
* Fixed ruff errors
TheRealMichaelWang Nov 12, 2025
c31f7af
* Fixed mypy issues
TheRealMichaelWang Nov 12, 2025
d6f2d25
* Renamed GetAttribute to GetAttr to be consistent with FinchAssembl…
TheRealMichaelWang Nov 12, 2025
9025903
Merge branch 'main' into mw/add-insum-interpreter-support
TheRealMichaelWang Nov 12, 2025
b59cc0a
Refactor EinsumInterpreter to streamline index evaluation and groupin…
TheRealMichaelWang Nov 12, 2025
5d0b80b
* Fixed ruff errors
TheRealMichaelWang Nov 12, 2025
3fe28f7
Merge branch 'main' into mw/add-insum-interpreter-support
willow-ahrens Nov 12, 2025
b387a0a
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Nov 22, 2025
2a3adec
Merge branch 'mw/add-insum-interpreter-support' of https://github.com…
TheRealMichaelWang Nov 22, 2025
64bfe6e
* Fixed minor issues in einsum node.py to adjust for new base einsum …
TheRealMichaelWang Nov 22, 2025
8e69d16
* Changed indirect access implementation to a recursive approach
TheRealMichaelWang Nov 22, 2025
0e527c7
Rewrote ein.Access implementation to access tensors with a mixture of…
TheRealMichaelWang Nov 23, 2025
f163f65
* Added two tests to einsum tests
TheRealMichaelWang Nov 23, 2025
656aabe
* Added comments to enhance clarity in ein.Access for multiple indices
TheRealMichaelWang Nov 23, 2025
4bf1a78
* Simplified calculation of dest_axes in ein.Access
TheRealMichaelWang Nov 23, 2025
bd8e22c
* Fixed ruff errors
TheRealMichaelWang Nov 23, 2025
dd0167d
* Fixed end of file issues in .gitignore
TheRealMichaelWang Nov 23, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -174,3 +174,6 @@ cython_debug/
.pixi/
*.egg-info
pixi.lock

# vscode debugging configurations
.vscode/
2 changes: 2 additions & 0 deletions src/finchlite/finch_einsum/__init__.py
@@ -7,6 +7,7 @@
EinsumExpression,
EinsumNode,
EinsumStatement,
GetAttr,
Index,
Literal,
Plan,
Expand All @@ -26,6 +27,7 @@
"EinsumScheduler",
"EinsumScheduler",
"EinsumStatement",
"GetAttr",
"Index",
"Literal",
"Plan",
Expand Down
132 changes: 125 additions & 7 deletions src/finchlite/finch_einsum/interpreter.py
@@ -3,6 +3,7 @@
import numpy as np

from ..algebra import overwrite, promote_max, promote_min
from ..symbolic import ftype
from . import nodes as ein

nary_ops = {
@@ -78,6 +79,11 @@ def __init__(self, xp=None, bindings=None, loops=None):
self.loops = loops

def __call__(self, node):
from ..tensor import (
SparseTensor,
SparseTensorFType,
)

xp = self.xp
match node:
case ein.Literal(val):
@@ -92,26 +98,132 @@ def __call__(self, node):
func = getattr(xp, nary_ops[func])
vals = [self(arg) for arg in args]
return func(*vals)

# access a tensor with only one indirect access index
case ein.Access(tns, idxs) if len(idxs) == 1 and not isinstance(
idxs[0], ein.Index
):
idx = self(idxs[0])
tns = self(tns) # evaluate the tensor

flat_idx = (
idx if idx.ndim == 1 else xp.ravel_multi_index(idx.T, tns.shape)
)
return tns.flat[flat_idx] # return a 1-d array by definition
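
As an illustration, a minimal NumPy sketch of this flattened gather; A and coords are hypothetical stand-ins for the evaluated tensor and index expression:

import numpy as np

A = np.arange(12).reshape(3, 4)                 # evaluated tensor
coords = np.array([[0, 1], [2, 3]])             # one (row, col) pair per output element
flat = np.ravel_multi_index(coords.T, A.shape)  # -> array([ 1, 11])
print(A.flat[flat])                             # -> [ 1 11], a 1-d result as above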

# access a tensor with a mixture of indices and other expressions
case ein.Access(tns, idxs):
assert len(idxs) == len(set(idxs))
assert self.loops is not None
tns = self(tns)
indirect_idxs = [idx for idx in idxs if not isinstance(idx, ein.Index)]

# base case: no indirect indices, just permute the dimensions
if len(indirect_idxs) == 0:
perm = [idxs.index(idx) for idx in self.loops if idx in idxs]
if hasattr(tns, "ndim") and len(perm) < tns.ndim:
perm += list(range(len(perm), tns.ndim))

tns = xp.permute_dims(tns, perm) # permute the dimensions
return xp.expand_dims(
tns,
[
i
for i in range(len(self.loops))
if self.loops[i] not in idxs
],
)
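
A small sketch of this base case, with assumed loops [i, j, k] and access indices [k, i] (np.permute_dims requires NumPy >= 2.0 for the array-API name):

import numpy as np

A = np.ones((5, 3))             # axis 0 = k, axis 1 = i
B = np.permute_dims(A, (1, 0))  # reorder to loop order: axis 0 = i, axis 1 = k
B = np.expand_dims(B, 1)        # insert the missing j axis
print(B.shape)                  # (3, 1, 5), broadcastable over the j loop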

start_index = idxs.index(
indirect_idxs[0]
) # index of first indirect access
iterator_idxs = indirect_idxs[
0
].get_idxs() # iterator indices of the first indirect access
assert len(iterator_idxs) == 1

# get the axes of the idxs that are associated
# with the current iterator indices
target_axes = [
i
for i, idx in enumerate(idxs[start_index:], start_index)
if idx.get_idxs().issubset(iterator_idxs)
]

# get associated access indices w/ the first indirect access
current_idxs = [idxs[i] for i in target_axes]

# evaluate the associated access indices
evaled_idxs = [
xp.arange(tns.shape[idxs.index(idx)])
if isinstance(idx, ein.Index)
else self(idx).flat
for idx in current_idxs
]

dest_axes = list(range(len(current_idxs)))
tns = xp.moveaxis(tns, target_axes, dest_axes)

# access the tensor with the evaled idxs
tns = tns[tuple(evaled_idxs)]

# restore original tensor axis order
tns = xp.moveaxis(tns, source=0, destination=target_axes[0])

# we recursively call the interpreter with the remaining idxs
iterator_idx = next(iter(iterator_idxs))
new_idxs = (
list(idxs[:start_index])
+ [iterator_idx]
+ [
idx
for idx in idxs[start_index + 1 :]
if idx not in current_idxs
]
)

new_access = ein.Access(ein.Literal(tns), new_idxs)
return self(new_access)
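
For reference, one recursion step in plain NumPy for a hypothetical access A[p[i], i], where both access positions depend on the same iterator index i:

import numpy as np

A = np.arange(12).reshape(4, 3)
p = np.array([3, 0, 2])      # evaluated indirect index over i
rows = p                     # indirect axis
cols = np.arange(3)          # plain index i becomes arange over its extent
print(A[rows, cols])         # advanced indexing fuses both axes -> [9 1 8]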

case ein.Plan(bodies):
res = None
for body in bodies:
res = self(body)
return res
case ein.Produces(args):
return tuple(self(arg) for arg in args)

# get non-zero elements/data array of a sparse tensor
case ein.GetAttr(obj, ein.Literal("elems"), _):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)
return obj.data
# get coord array of a sparse tensor
case ein.GetAttr(obj, ein.Literal("coords"), dim):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)

# return the coord array for the given dimension or all dimensions
return obj.coords if dim is None else obj.coords[:, dim].flat
# gets the shape of a sparse tensor at a given dimension
case ein.GetAttr(obj, ein.Literal("shape"), dim):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)
assert dim is not None

# return the shape for the given dimension
return obj.shape[dim]
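
A sketch of what these attribute reads yield, using the SparseTensor added later in this diff (values are hypothetical):

import numpy as np
from finchlite.tensor import SparseTensor

S = SparseTensor.from_dense_tensor(np.array([[0.0, 5.0], [7.0, 0.0]]))
print(S.data)      # "elems"  -> [5. 7.]
print(S.coords)    # "coords", all dims -> [[0 1] [1 0]]
print(S.shape[1])  # "shape" at dim 1 -> 2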

# standard einsum
case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(
isinstance(idx, ein.Index) for idx in idxs
):
# This is the main entry point for einsum execution
loops = arg.get_idxs()

assert set(idxs).issubset(loops)
loops = sorted(loops, key=lambda x: x.name)
ctx = EinsumInterpreter(self.xp, self.bindings, loops)
@@ -128,5 +240,11 @@ def __call__(self, node):
axis = [dropped.index(idx) for idx in idxs]
self.bindings[tns] = xp.permute_dims(val, axis)
return (tns,)

# indirect einsum
case ein.Einsum(op, ein.Alias(tns), idxs, arg):
raise NotImplementedError(
"Indirect einsum assignment is not implemented"
)
case _:
raise ValueError(f"Unknown einsum type: {type(node)}")
46 changes: 46 additions & 0 deletions src/finchlite/finch_einsum/nodes.py
@@ -186,6 +186,42 @@ def get_idxs(self) -> set["Index"]:
return idxs


@dataclass(eq=True, frozen=True)
class GetAttr(EinsumExpression, EinsumTree):
"""
Gets an attribute of a tensor.
Attributes:
obj: The object to get the attribute from.
attr: The name of the attribute to get.
dim: The dimension to get the attribute from.
Note this is an integer index, not a named index.
"""

obj: EinsumExpression
attr: Literal
dim: int | None

@classmethod
def from_children(cls, *children: Term) -> Self:
# Expects 3 children: obj, attr, dim
if len(children) != 3:
raise ValueError("GetAttribute expects 3 children (obj + attr + idx)")
obj = cast(EinsumExpression, children[0])
attr = cast(Literal, children[1])
dim = cast(int | None, children[2])
return cls(obj, attr, dim)

@property
def children(self):
return [self.obj, self.attr, self.dim]

def get_idxs(self) -> set["Index"]:
idxs = set()
idxs.update(self.obj.get_idxs())
return idxs
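
A construction sketch using only constructors shown in this diff; Literal("A") stands in for an arbitrary EinsumExpression, and the printer case later in this diff renders the node as obj.attr[dim]:

from finchlite.finch_einsum import GetAttr, Literal

node = GetAttr(Literal("A"), Literal("coords"), 0)    # obj is a Literal here for brevity
print(node.children)                                  # [obj, attr, dim]
assert GetAttr.from_children(*node.children) == node  # round-trips; eq=True dataclass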


@dataclass(eq=True, frozen=True)
class Einsum(EinsumTree, EinsumStatement):
"""
@@ -224,6 +260,12 @@ def from_children(cls, *children: Term) -> Self:
def children(self):
return [self.op, self.tns, self.idxs, self.arg]

def get_idxs(self) -> set["Index"]:
idxs = set()
for idx in self.idxs:
idxs.update(idx.get_idxs())
return idxs


@dataclass(eq=True, frozen=True)
class Plan(EinsumTree, EinsumStatement):
@@ -344,6 +386,10 @@ def __call__(self, prgm: EinsumNode):
if len(args) == 1 and fn.val in unary_strs:
return f"{unary_strs[fn.val]}{args_e[0]}"
return f"{self(fn)}({', '.join(args_e)})"
case GetAttr(obj, attr, idx):
if idx is not None:
return f"{self(obj)}.{self(attr)}[{idx}]"
return f"{self(obj)}.{self(attr)}"
case Einsum(op, tns, idxs, arg):
op_str = infix_strs.get(op.val, op.val.__name__)
self.exec(
3 changes: 3 additions & 0 deletions src/finchlite/tensor/__init__.py
@@ -7,6 +7,7 @@
dense,
element,
)
from .sparse_tensor import SparseTensor, SparseTensorFType

__all__ = [
"DenseLevel",
@@ -17,6 +18,8 @@
"FiberTensorFType",
"Level",
"LevelFType",
"SparseTensor",
"SparseTensorFType",
"dense",
"element",
"tensor",
103 changes: 103 additions & 0 deletions src/finchlite/tensor/sparse_tensor.py
@@ -0,0 +1,103 @@
import numpy as np

from finchlite.algebra import TensorFType
from finchlite.interface.eager import EagerTensor


class SparseTensorFType(TensorFType):
def __init__(self, shape: tuple, element_type: type):
self.shape = shape
self._element_type = element_type

def __eq__(self, other):
if not isinstance(other, SparseTensorFType):
return False
return self.shape == other.shape and self.element_type == other.element_type

def __hash__(self):
return hash((self.shape, self.element_type))

@property
def ndim(self):
return len(self.shape)

@property
def shape_type(self):
return self.shape

@property
def element_type(self):
return self._element_type

@property
def fill_value(self):
return 0
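
Since __eq__ and __hash__ are defined over (shape, element_type), equal ftypes are interchangeable, e.g. as dict keys; a quick sketch:

import numpy as np
from finchlite.tensor import SparseTensorFType

ft1 = SparseTensorFType((3, 4), np.float64)
ft2 = SparseTensorFType((3, 4), np.float64)
assert ft1 == ft2 and hash(ft1) == hash(ft2)  # equal shape/element type -> same key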


# currently implemented with COO tensor
class SparseTensor(EagerTensor):
def __init__(
self,
data: np.typing.NDArray,
coords: np.typing.NDArray,
shape: tuple,
element_type=np.float64,
):
if data.shape[0] != coords.shape[0]:
raise ValueError("data and coords must have the same number of rows")

self.coords = coords
self.data = data
self._shape = shape
self._element_type = element_type

# converts an eager tensor to a sparse tensor
@classmethod
def from_dense_tensor(cls, dense_tensor: np.ndarray):
coords = np.where(dense_tensor != 0)
data = dense_tensor[coords]
shape = dense_tensor.shape
element_type = dense_tensor.dtype.type
coords_array = np.array(coords).T
return cls(data, coords_array, shape, element_type)
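
A usage sketch of the COO construction (values are hypothetical):

import numpy as np
from finchlite.tensor import SparseTensor

D = np.array([[0.0, 5.0], [7.0, 0.0]])
S = SparseTensor.from_dense_tensor(D)
print(S.coords)   # [[0 1] [1 0]] -- one (row, col) row per nonzero
print(S.data)     # [5. 7.]
print(S.density)  # 0.5 -> 2 nonzeros / 4 elements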

@property
def ftype(self):
return SparseTensorFType(self.shape, self._element_type)

@property
def shape(self):
return self._shape

@property
def ndim(self) -> np.intp:
return np.intp(len(self._shape))

# calculates the ratio of non-zero elements to the total number of elements
@property
def density(self):
return self.coords.shape[0] / np.prod(self.shape)

def __getitem__(self, idx: tuple):
if len(idx) != self.ndim:
raise ValueError(f"Index must have {self.ndim} dimensions")

# coords is a 2D array where each row is a coordinate
mask = np.all(self.coords == idx, axis=1)
matching_indices = np.where(mask)[0]

if len(matching_indices) > 0:
return self.data[matching_indices[0]]
return 0
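
A lookup sketch: a coordinate present in coords returns its stored value, anything else falls through to the implicit zero fill:

import numpy as np
from finchlite.tensor import SparseTensor

S = SparseTensor.from_dense_tensor(np.array([[0, 3], [0, 0]]))
print(S[(0, 1)])  # 3 -- matched a row of coords
print(S[(1, 0)])  # 0 -- no match, implicit fill value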

def __str__(self):
return (
f"SparseTensor(data={self.data}, coords={self.coords},"
f" shape={self.shape}, element_type={self._element_type})"
)

def to_dense(self) -> np.ndarray:
dense_tensor = np.zeros(self.shape, dtype=self._element_type)
for i in range(self.coords.shape[0]):
dense_tensor[tuple(self.coords[i])] = self.data[i]
return dense_tensor
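
A round-trip sketch; to_dense rebuilds the dense tensor the COO form was derived from:

import numpy as np
from finchlite.tensor import SparseTensor

D = np.array([[0.0, 5.0], [7.0, 0.0]])
S = SparseTensor.from_dense_tensor(D)
assert np.array_equal(S.to_dense(), D)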