Commits (50)
193dbb1
* Added basic sparse tensor type.
TheRealMichaelWang Oct 28, 2025
101caf3
* Fixed ruff errors
TheRealMichaelWang Oct 28, 2025
ae38183
* Fixed mypy typing issues
TheRealMichaelWang Oct 28, 2025
bddbac9
* Fixed ruff whitespace errors
TheRealMichaelWang Oct 28, 2025
7318931
* Added in dimension check for safety
TheRealMichaelWang Oct 28, 2025
3d27adc
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Nov 4, 2025
335c75c
* Added GetAttribute Einsum IR node
TheRealMichaelWang Nov 5, 2025
6ab5e06
* Added support for printing GetAttribute Einsum IR node
TheRealMichaelWang Nov 5, 2025
e176ef4
* Added einsum interpreter support for evaluating GetAttribute IR node
TheRealMichaelWang Nov 5, 2025
2936eac
* Added support for indirection in ein.Access
TheRealMichaelWang Nov 5, 2025
2c061ab
* Added support for indirection with mutli-dimension indicies. I.e. A…
TheRealMichaelWang Nov 5, 2025
46fda9c
* Added comments explaining indirect access implementation
TheRealMichaelWang Nov 5, 2025
d32ae5b
* Fixed bugs in einsum interpreter that stemmed from newly added supp…
TheRealMichaelWang Nov 6, 2025
5f09be4
* Fixed some bugs
TheRealMichaelWang Nov 6, 2025
860e016
* Revert changes to ein.Access
TheRealMichaelWang Nov 7, 2025
bf43dbb
* Added seperate match case handlers in einsum interpreter loop to ha…
TheRealMichaelWang Nov 7, 2025
2f14a25
* Added support for one index indirect access
TheRealMichaelWang Nov 7, 2025
9c57374
* Added support for multiple indirect indicies in access in einsum in…
TheRealMichaelWang Nov 7, 2025
5ea1ddd
Enhanced einsum interpreter to evaluate tensor shapes and permute dim…
TheRealMichaelWang Nov 7, 2025
71268a8
* Added match case in einsum interpreter loop to hand einsums with in…
TheRealMichaelWang Nov 7, 2025
24ef52c
* Added support for getting shape attribute of a sparse tensor
TheRealMichaelWang Nov 7, 2025
85d81f7
Implemented direct einsum handling for indirect assignments without r…
TheRealMichaelWang Nov 7, 2025
710a2e3
Refactored indirect einsum handling in the interpreter to support red…
TheRealMichaelWang Nov 7, 2025
1dc9667
Removed implementation of indirect einsum with reduction in the inter…
TheRealMichaelWang Nov 11, 2025
d953757
Enhanced EinsumInterpreter to handle cases with fewer indices than di…
TheRealMichaelWang Nov 11, 2025
09c9879
Renamed test for indirect access to clarify focus on elementwise mult…
TheRealMichaelWang Nov 11, 2025
f65e4f9
Add tests for indirect access in einsum operations
TheRealMichaelWang Nov 11, 2025
7549b62
Refactor EinsumInterpreter to handle flat indexing for 1D arrays and …
TheRealMichaelWang Nov 11, 2025
29f5f03
Refactor EinsumInterpreter to improve handling of indirect indexing a…
TheRealMichaelWang Nov 11, 2025
6fe0528
Remove redundant calculation of parent indices in EinsumInterpreter t…
TheRealMichaelWang Nov 11, 2025
8d0a788
Refactor EinsumInterpreter to enhance indirect indexing logic by sepa…
TheRealMichaelWang Nov 12, 2025
cc09bf5
Refactor EinsumInterpreter to improve index handling by grouping ein.…
TheRealMichaelWang Nov 12, 2025
169cebf
Refactor EinsumInterpreter to enhance index classification and groupi…
TheRealMichaelWang Nov 12, 2025
b2f81e6
* Fixed ruff errors
TheRealMichaelWang Nov 12, 2025
c31f7af
* Fixed mypy issues
TheRealMichaelWang Nov 12, 2025
d6f2d25
* Renamed GetAttribute to GetAttr to be consistent with FinchAssembl…
TheRealMichaelWang Nov 12, 2025
9025903
Merge branch 'main' into mw/add-insum-interpreter-support
TheRealMichaelWang Nov 12, 2025
b59cc0a
Refactor EinsumInterpreter to streamline index evaluation and groupin…
TheRealMichaelWang Nov 12, 2025
5d0b80b
* Fixed ruff errors
TheRealMichaelWang Nov 12, 2025
3fe28f7
Merge branch 'main' into mw/add-insum-interpreter-support
willow-ahrens Nov 12, 2025
b387a0a
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Nov 22, 2025
2a3adec
Merge branch 'mw/add-insum-interpreter-support' of https://github.com…
TheRealMichaelWang Nov 22, 2025
64bfe6e
* Fixed minor issues in einsum node.py to adjust for new base einsum …
TheRealMichaelWang Nov 22, 2025
8e69d16
* Changed indirect access implementation to a recursive approach
TheRealMichaelWang Nov 22, 2025
0e527c7
Rewrote ein.Access implementation to access tensors with a mixture of…
TheRealMichaelWang Nov 23, 2025
f163f65
* Added two tests to einsum tests
TheRealMichaelWang Nov 23, 2025
656aabe
* Added comments to enhance clarity in ein.Access for multiple indicies
TheRealMichaelWang Nov 23, 2025
4bf1a78
* Simplified calculation of dest_axes in ein.Access
TheRealMichaelWang Nov 23, 2025
bd8e22c
* Fixed ruff errors
TheRealMichaelWang Nov 23, 2025
dd0167d
* Fixed end of file issues in .gitignore
TheRealMichaelWang Nov 23, 2025
2 changes: 2 additions & 0 deletions src/finchlite/finch_einsum/__init__.py
@@ -6,6 +6,7 @@
Einsum,
EinsumExpr,
EinsumNode,
GetAttr,
Index,
Literal,
Plan,
@@ -24,6 +25,7 @@
"EinsumNode",
"EinsumScheduler",
"EinsumScheduler",
"GetAttr",
"Index",
"Literal",
"Plan",
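With this export, GetAttr is available alongside the other einsum IR nodes. A minimal construction sketch (it assumes Alias is also re-exported, which the truncated import list above does not show):

from finchlite.finch_einsum import Alias, GetAttr, Literal

# an IR node meaning "shape of A along dimension 0"
node = GetAttr(Alias("A"), Literal("shape"), 0)
assert node.children == [node.obj, node.attr, node.dim]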
222 changes: 219 additions & 3 deletions src/finchlite/finch_einsum/interpreter.py
@@ -3,6 +3,7 @@
import numpy as np

from ..algebra import overwrite, promote_max, promote_min
from ..symbolic import ftype
from . import nodes as ein

nary_ops = {
@@ -78,6 +79,11 @@ def __init__(self, xp=None, bindings=None, loops=None):
self.loops = loops

def __call__(self, node):
from ..tensor import (
SparseTensor,
SparseTensorFType,
)

xp = self.xp
match node:
case ein.Literal(val):
@@ -92,26 +98,230 @@
func = getattr(xp, nary_ops[func])
vals = [self(arg) for arg in args]
return func(*vals)
case ein.Access(tns, idxs):

# access a tensor with only named indices
case ein.Access(tns, idxs) if all(
isinstance(idx, ein.Index) for idx in idxs
):
assert len(idxs) == len(set(idxs))
assert self.loops is not None

# convert named idxs to positional, integer indices
perm = [idxs.index(idx) for idx in self.loops if idx in idxs]

tns = self(tns) # evaluate the tensor

# if there are fewer indices than dimensions, append the remaining
# dimensions as if they weren't permuted
if hasattr(tns, "ndim") and len(perm) < tns.ndim:
perm += list(range(len(perm), tns.ndim))

tns = xp.permute_dims(tns, perm) # permute the dimensions
return xp.expand_dims(
tns,
[i for i in range(len(self.loops)) if self.loops[i] not in idxs],
)

# access a tensor with only one indirect access index
case ein.Access(tns, idxs) if len(idxs) == 1:
idx = self(idxs[0])
tns = self(tns) # evaluate the tensor

flat_idx = (
idx if idx.ndim == 1 else xp.ravel_multi_index(idx.T, tns.shape)
)
return tns.flat[flat_idx] # return a 1-d array by definition

# access a tensor with a mixture of indices and other expressions
case ein.Access(tns, idxs):
assert self.loops is not None
true_idxs = node.get_idxs()  # true field iterator indices
assert all(isinstance(idx, ein.Index) for idx in true_idxs)

# evaluate the tensor to access
tns = self(tns)
assert len(idxs) == len(tns.shape)

# Classify indices and determine their parent groups
idx_sizes = [] # Track size of each index dimension

# For each index, determine which true_idx it depends on
parent_indices = []
for i, idx in enumerate(idxs):
if isinstance(idx, ein.Index):
parent_indices.append(idx)
idx_sizes.append(tns.shape[i])
else:
child_idxs = idx.get_idxs()
parent_indices.append(
list(child_idxs)[0] if len(child_idxs) == 1 else idx
)
idx_sizes.append(None) # Will be determined later

# Find unique parent groups and their positions
# Create a mapping from parent to group index
unique_parents = []
parent_to_group = {}
for parent in parent_indices:
found = False
for i, up in enumerate(unique_parents):
# For ein.Index, compare by name; for others by identity
if isinstance(parent, ein.Index) and isinstance(up, ein.Index):
if parent.name == up.name:
parent_to_group[id(parent)] = i
found = True
break
elif parent is up:
parent_to_group[id(parent)] = i
found = True
break
if not found:
parent_to_group[id(parent)] = len(unique_parents)
unique_parents.append(parent)

# Build group positions from the mapping
group_positions = [[] for _ in range(len(unique_parents))]
for i, parent in enumerate(parent_indices):
group_idx = parent_to_group[id(parent)]
group_positions[group_idx].append(i)

# Build ranges for cartesian product
group_ranges = []
for positions in group_positions:
first_pos = positions[0]
if isinstance(idxs[first_pos], ein.Index):
group_ranges.append(xp.arange(tns.shape[first_pos]))
else:
indirect_vals = self(idxs[first_pos]).flatten()
group_ranges.append(xp.arange(len(indirect_vals)))
# Update sizes for all positions in this group
for p in positions:
idx_sizes[p] = len(indirect_vals)

# Compute cartesian product
group_grids = xp.meshgrid(*group_ranges, indexing="ij")
group_combos = xp.stack([g.flatten() for g in group_grids], axis=-1)

# Build index combinations
combo_idxs = xp.empty(
(group_combos.shape[0], len(idxs)), dtype=xp.int64
)

# Pre-evaluate all indirect indices
indirect_vals_cache = {}
for i, idx in enumerate(idxs):
if not isinstance(idx, ein.Index):
indirect_vals_cache[i] = self(idx).flatten()

# Fill combo_idxs using vectorized operations where possible
for group_idx, positions in enumerate(group_positions):
group_iterations = group_combos[:, group_idx]
for pos in positions:
if isinstance(idxs[pos], ein.Index):
combo_idxs[:, pos] = group_iterations
else:
combo_idxs[:, pos] = indirect_vals_cache[pos][
group_iterations
]

# evaluate the output tensor as a flat array
flat_idx = xp.ravel_multi_index(combo_idxs.T, tns.shape)
tns = xp.take(tns, flat_idx)

# calculate child idxs: the idxs computed from each parent "true idx"
child_idxs = {
parent_idx: [
child_idx
for child_idx in idxs
if (parent_idx in child_idx.get_idxs())
]
for parent_idx in true_idxs
}

# assert that all the indirect access indices derived from
# the same parent idx have the same size
child_idxs_size = {
parent_idx: [
idx_sizes[idxs.index(child_idx)]
for child_idx in child_idxs[parent_idx]
]
for parent_idx in true_idxs
}
assert all(
child_idxs_size[parent_idx].count(child_idxs_size[parent_idx][0])
== len(child_idxs_size[parent_idx])
for parent_idx in true_idxs
)

# a mapping from each idx to its axis in the current shape
idxs_axis = {idx: i for i, idx in enumerate(idxs)}

true_idxs = list(true_idxs)
true_idxs = sorted(
true_idxs, key=lambda idx: idxs_axis[child_idxs[idx][0]]
)

# calculate the final shape of the tensor
# we merge the child idxs to get the final shape
# that matches the true idxs
final_shape = tuple(
idx_sizes[idxs.index(child_idxs[parent_idx][0])]
for parent_idx in true_idxs
)
tns = tns.reshape(final_shape)

# permute and broadcast the tensor to be
# compatible with the rest of the expression
perm = [true_idxs.index(idx) for idx in self.loops if idx in true_idxs]
tns = xp.permute_dims(tns, perm)
return xp.expand_dims(
tns,
[i for i in range(len(self.loops)) if self.loops[i] not in idxs],
[
i
for i in range(len(self.loops))
if self.loops[i] not in true_idxs
],
)

case ein.Plan(bodies):
res = None
for body in bodies:
res = self(body)
return res
case ein.Produces(args):
return tuple(self(arg) for arg in args)
case ein.Einsum(op, ein.Alias(tns), idxs, arg):

# get non-zero elements/data array of a sparse tensor
case ein.GetAttr(obj, ein.Literal("elems"), _):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)
return obj.data
# get coord array of a sparse tensor
case ein.GetAttr(obj, ein.Literal("coords"), dim):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)

# return the coord array for the given dimension or all dimensions
return obj.coords if dim is None else obj.coords[:, dim]
# get the shape of a sparse tensor at a given dimension
case ein.GetAttr(obj, ein.Literal("shape"), dim):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)
assert dim is not None

# return the shape for the given dimension
return obj.shape[dim]

# standard einsum
case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(
isinstance(idx, ein.Index) for idx in idxs
):
# This is the main entry point for einsum execution
loops = arg.get_idxs()

assert set(idxs).issubset(loops)
loops = sorted(loops, key=lambda x: x.name)
ctx = EinsumInterpreter(self.xp, self.bindings, loops)
@@ -128,5 +338,11 @@ def __call__(self, node):
axis = [dropped.index(idx) for idx in idxs]
self.bindings[tns] = xp.permute_dims(val, axis)
return (tns,)

# indirect einsum
case ein.Einsum(op, ein.Alias(tns), idxs, arg):
raise NotImplementedError(
"Indirect einsum assignment is not implemented"
)
case _:
raise ValueError(f"Unknown einsum type: {type(node)}")
46 changes: 46 additions & 0 deletions src/finchlite/finch_einsum/nodes.py
@@ -176,6 +176,42 @@ def get_idxs(self) -> set["Index"]:
return idxs


@dataclass(eq=True, frozen=True)
class GetAttr(EinsumExpr, EinsumTree):
"""
Gets an attribute of a tensor.
Attributes:
obj: The object to get the attribute from.
attr: The name of the attribute to get.
dim: The dimension to read the attribute at, or None for all.
Note this is a positional integer, not a named index.
"""

obj: EinsumExpr
attr: Literal
dim: int | None

@classmethod
def from_children(cls, *children: Term) -> Self:
# Expects 3 children: obj, attr, dim
if len(children) != 3:
raise ValueError("GetAttr expects 3 children (obj, attr, dim)")
obj = cast(EinsumExpr, children[0])
attr = cast(Literal, children[1])
dim = cast(int | None, children[2])
return cls(obj, attr, dim)

@property
def children(self):
return [self.obj, self.attr, self.dim]

def get_idxs(self) -> set["Index"]:
idxs = set()
idxs.update(self.obj.get_idxs())
return idxs


@dataclass(eq=True, frozen=True)
class Einsum(EinsumTree):
"""
@@ -214,6 +250,12 @@ def from_children(cls, *children: Term) -> Self:
def children(self):
return [self.op, self.tns, self.idxs, self.arg]

def get_idxs(self) -> set["Index"]:
idxs = set()
for idx in self.idxs:
idxs.update(idx.get_idxs())
return idxs


@dataclass(eq=True, frozen=True)
class Plan(EinsumTree):
@@ -334,6 +376,10 @@ def __call__(self, prgm: EinsumNode):
if len(args) == 1 and fn.val in unary_strs:
return f"{unary_strs[fn.val]}{args_e[0]}"
return f"{self(fn)}({', '.join(args_e)})"
case GetAttr(obj, attr, dim):
if dim is not None:
return f"{self(obj)}.{self(attr)}[{dim}]"
return f"{self(obj)}.{self(attr)}"
case Einsum(op, tns, idxs, arg):
op_str = infix_strs.get(op.val, op.val.__name__)
self.exec(
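Because GetAttr is a frozen dataclass whose children property mirrors from_children, it round-trips cleanly through term rewriting, and the printer renders it as attribute access with an optional dimension subscript. A quick sketch (Literal here is only a stand-in obj; any EinsumExpr works):

from finchlite.finch_einsum import GetAttr, Literal

node = GetAttr(Literal("A"), Literal("coords"), 1)
rebuilt = GetAttr.from_children(*node.children)
assert rebuilt == node  # eq=True, frozen=True make the round-trip comparable
# the printer case above renders this roughly as:  A.coords[1]
# with dim=None the subscript is dropped:          A.coords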
3 changes: 3 additions & 0 deletions src/finchlite/tensor/__init__.py
@@ -7,6 +7,7 @@
dense,
element,
)
from .sparse_tensor import SparseTensor, SparseTensorFType

__all__ = [
"DenseLevel",
@@ -17,6 +18,8 @@
"FiberTensorFType",
"Level",
"LevelFType",
"SparseTensor",
"SparseTensorFType",
"dense",
"element",
"tensor",
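For reference, the three sparse-tensor attributes the interpreter reads map onto a COO-style layout: elems is the data vector, coords the (nnz, ndim) coordinate array, and shape a tuple of extents. A sketch of the corresponding arrays (the SparseTensor field names .data and .coords follow the interpreter's handlers; the values are made up):

import numpy as np

data = np.array([5.0, 7.0, 9.0])             # GetAttr(obj, "elems", None) -> obj.data
coords = np.array([[0, 1], [1, 3], [2, 0]])  # GetAttr(obj, "coords", None) -> obj.coords
shape = (3, 4)

col = coords[:, 1]  # GetAttr(obj, "coords", dim=1): coordinates along dim 1
extent = shape[0]   # GetAttr(obj, "shape", dim=0): a single extent; dim is required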