Changes from 23 commits

Commits (50 total)
193dbb1
* Added basic sparse tensor type.
TheRealMichaelWang Oct 28, 2025
101caf3
* Fixed ruff errors
TheRealMichaelWang Oct 28, 2025
ae38183
* Fixed mypy typing issues
TheRealMichaelWang Oct 28, 2025
bddbac9
* Fixed ruff whitespace errors
TheRealMichaelWang Oct 28, 2025
7318931
* Added in dimension check for safety
TheRealMichaelWang Oct 28, 2025
3d27adc
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Nov 4, 2025
335c75c
* Added GetAttribute Einsum IR node
TheRealMichaelWang Nov 5, 2025
6ab5e06
* Added support for printing GetAttribute Einsum IR node
TheRealMichaelWang Nov 5, 2025
e176ef4
* Added einsum interpreter support for evaluating GetAttribute IR node
TheRealMichaelWang Nov 5, 2025
2936eac
* Added support for indirection in ein.Access
TheRealMichaelWang Nov 5, 2025
2c061ab
* Added support for indirection with multi-dimension indices. I.e. A…
TheRealMichaelWang Nov 5, 2025
46fda9c
* Added comments explaining indirect access implementation
TheRealMichaelWang Nov 5, 2025
d32ae5b
* Fixed bugs in einsum interpreter that stemmed from newly added supp…
TheRealMichaelWang Nov 6, 2025
5f09be4
* Fixed some bugs
TheRealMichaelWang Nov 6, 2025
860e016
* Revert changes to ein.Access
TheRealMichaelWang Nov 7, 2025
bf43dbb
* Added separate match case handlers in einsum interpreter loop to ha…
TheRealMichaelWang Nov 7, 2025
2f14a25
* Added support for one index indirect access
TheRealMichaelWang Nov 7, 2025
9c57374
* Added support for multiple indirect indices in access in einsum in…
TheRealMichaelWang Nov 7, 2025
5ea1ddd
Enhanced einsum interpreter to evaluate tensor shapes and permute dim…
TheRealMichaelWang Nov 7, 2025
71268a8
* Added match case in einsum interpreter loop to handle einsums with in…
TheRealMichaelWang Nov 7, 2025
24ef52c
* Added support for getting shape attribute of a sparse tensor
TheRealMichaelWang Nov 7, 2025
85d81f7
Implemented direct einsum handling for indirect assignments without r…
TheRealMichaelWang Nov 7, 2025
710a2e3
Refactored indirect einsum handling in the interpreter to support red…
TheRealMichaelWang Nov 7, 2025
1dc9667
Removed implementation of indirect einsum with reduction in the inter…
TheRealMichaelWang Nov 11, 2025
d953757
Enhanced EinsumInterpreter to handle cases with fewer indices than di…
TheRealMichaelWang Nov 11, 2025
09c9879
Renamed test for indirect access to clarify focus on elementwise mult…
TheRealMichaelWang Nov 11, 2025
f65e4f9
Add tests for indirect access in einsum operations
TheRealMichaelWang Nov 11, 2025
7549b62
Refactor EinsumInterpreter to handle flat indexing for 1D arrays and …
TheRealMichaelWang Nov 11, 2025
29f5f03
Refactor EinsumInterpreter to improve handling of indirect indexing a…
TheRealMichaelWang Nov 11, 2025
6fe0528
Remove redundant calculation of parent indices in EinsumInterpreter t…
TheRealMichaelWang Nov 11, 2025
8d0a788
Refactor EinsumInterpreter to enhance indirect indexing logic by sepa…
TheRealMichaelWang Nov 12, 2025
cc09bf5
Refactor EinsumInterpreter to improve index handling by grouping ein.…
TheRealMichaelWang Nov 12, 2025
169cebf
Refactor EinsumInterpreter to enhance index classification and groupi…
TheRealMichaelWang Nov 12, 2025
b2f81e6
* Fixed ruff errors
TheRealMichaelWang Nov 12, 2025
c31f7af
* Fixed mypy issues
TheRealMichaelWang Nov 12, 2025
d6f2d25
* Renamed GetAttribute to GetAttr to be consistent with FinchAssembl…
TheRealMichaelWang Nov 12, 2025
9025903
Merge branch 'main' into mw/add-insum-interpreter-support
TheRealMichaelWang Nov 12, 2025
b59cc0a
Refactor EinsumInterpreter to streamline index evaluation and groupin…
TheRealMichaelWang Nov 12, 2025
5d0b80b
* Fixed ruff errors
TheRealMichaelWang Nov 12, 2025
3fe28f7
Merge branch 'main' into mw/add-insum-interpreter-support
willow-ahrens Nov 12, 2025
b387a0a
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Nov 22, 2025
2a3adec
Merge branch 'mw/add-insum-interpreter-support' of https://github.com…
TheRealMichaelWang Nov 22, 2025
64bfe6e
* Fixed minor issues in einsum node.py to adjust for new base einsum …
TheRealMichaelWang Nov 22, 2025
8e69d16
* Changed indirect access implementation to a recursive approach
TheRealMichaelWang Nov 22, 2025
0e527c7
Rewrote ein.Access implementation to access tensors with a mixture of…
TheRealMichaelWang Nov 23, 2025
f163f65
* Added two tests to einsum tests
TheRealMichaelWang Nov 23, 2025
656aabe
* Added comments to enhance clarity in ein.Access for multiple indices
TheRealMichaelWang Nov 23, 2025
4bf1a78
* Simplified calculation of dest_axes in ein.Access
TheRealMichaelWang Nov 23, 2025
bd8e22c
* Fixed ruff errors
TheRealMichaelWang Nov 23, 2025
dd0167d
* Fixed end of file issues in .gitignore
TheRealMichaelWang Nov 23, 2025
2 changes: 2 additions & 0 deletions src/finchlite/finch_einsum/__init__.py
@@ -7,6 +7,7 @@
EinsumExpr,
EinsumNode,
Index,
GetAttribute,
Literal,
Plan,
Produces,
@@ -25,6 +26,7 @@
"EinsumScheduler",
"EinsumScheduler",
"Index",
"GetAttribute",
"Literal",
"Plan",
"Produces",
151 changes: 146 additions & 5 deletions src/finchlite/finch_einsum/interpreter.py
@@ -2,8 +2,10 @@

import numpy as np

from ..algebra import overwrite, promote_max, promote_min
from ..algebra import overwrite, promote_max, promote_min, TensorFType
from . import nodes as ein
from ..symbolic import ftype, gensym


nary_ops = {
operator.add: "add",
@@ -78,6 +80,11 @@ def __init__(self, xp=None, bindings=None, loops=None):
self.loops = loops

def __call__(self, node):
from ..tensor import (
SparseTensor,
SparseTensorFType,
)

xp = self.xp
match node:
case ein.Literal(val):
@@ -92,26 +99,118 @@ def __call__(self, node):
func = getattr(xp, nary_ops[func])
vals = [self(arg) for arg in args]
return func(*vals)
case ein.Access(tns, idxs):

# access a tensor using only named indices
case ein.Access(tns, idxs) if all(isinstance(idx, ein.Index) for idx in idxs):
assert len(idxs) == len(set(idxs))
assert self.loops is not None

# convert named idxs to positional, integer indices
perm = [idxs.index(idx) for idx in self.loops if idx in idxs]
tns = self(tns)
tns = xp.permute_dims(tns, perm)

tns = self(tns)  # evaluate the tensor
tns = xp.permute_dims(tns, perm)  # permute the dimensions
return xp.expand_dims(
tns,
[i for i in range(len(self.loops)) if self.loops[i] not in idxs],
)
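# e.g., with loops [i, j, k] and idxs (k, i): perm is [1, 0], so the data
# is reordered to (i, k) and a length-1 axis is inserted for j, giving
# shape (|i|, 1, |k|) that broadcasts against the other operands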

# access a tensor with only one indirect access index
# (guarded on length so the mixed-index case below stays reachable)
case ein.Access(tns, idxs) if len(idxs) == 1:

idx = self(idxs[0])
tns = self(tns)  # evaluate the tensor

flat_idx = xp.ravel_multi_index(idx.T, tns.shape)
return tns.flat[flat_idx]  # returns a 1-d array by definition
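# e.g., if tns.shape == (3, 4) and idx == [[0, 1], [2, 3]] (one row per
# element, one column per dimension), ravel_multi_index(idx.T, (3, 4))
# gives [2, 7], so tns.flat[flat_idx] gathers tns[0, 2] and tns[1, 3]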

# access a tensor with a mixture of named indices and other expressions
case ein.Access(tns, idxs):
assert self.loops is not None
true_idxs = node.get_idxs()  # true field iterator indices
assert all(isinstance(idx, ein.Index) for idx in true_idxs)

# evaluate the tensor to access
tns = self(tns)
assert len(idxs) == len(tns.shape)

# evaluate all the indirect access indices; field indices are evaluated as a "grab all"
evaled_idxs = [
(xp.arange(tns.shape[i]) if isinstance(idx, ein.Index) else self(idx))
for i, idx in enumerate(idxs)
]

# gather the accessed elements as a flat array
flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns.shape)
tns = xp.take(tns, flat_idx)

# calculate the final shape of the tensor
# we assert that, for each parent idx, all of the child idxs that use it
# are identical (and therefore the same size)
# child idxs are the access idxs computed from the parent "true idxs"
child_idxs = {
parent_idx: [
child_idx for child_idx in idxs
if parent_idx in child_idx.get_idxs()
] for parent_idx in true_idxs
}
assert all(
child_idxs[parent_idx].count(child_idxs[parent_idx][0])
== len(child_idxs[parent_idx])
for parent_idx in true_idxs
)
# we merge the child idxs to get the final shape that matches the true idxs
final_shape = tuple(
evaled_idxs[idxs.index(idx)].size
for idx in idxs if idx in true_idxs
)
tns = tns.reshape(final_shape)
idxs = [idx for idx in idxs if idx in true_idxs]

# permute and broadcast the tensor to be compatible with the rest of the expression
perm = [idxs.index(idx) for idx in self.loops if idx in idxs]
tns = xp.permute_dims(tns, perm)
return xp.expand_dims(
tns,
[i for i in range(len(self.loops)) if self.loops[i] not in idxs]
)
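# e.g., for an access like B[i, p[i]] inside an einsum over loop i:
# evaled_idxs is [arange(m), p_vals], the flat gather yields m values,
# and final_shape collapses to (m,) with the single surviving axis i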

case ein.Plan(bodies):
res = None
for body in bodies:
res = self(body)
return res
case ein.Produces(args):
return tuple(self(arg) for arg in args)
case ein.Einsum(op, ein.Alias(tns), idxs, arg):

# get the non-zero elements (data array) of a sparse tensor
case ein.GetAttribute(obj, ein.Literal("elems"), _):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)
return obj.data
# get the coordinate array of a sparse tensor
case ein.GetAttribute(obj, ein.Literal("coords"), dim):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)

# return the coord array for the given dimension or all dimensions
return obj.coords if dim is None else obj.coords[:, dim]
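# e.g., a 2-d sparse tensor with nnz stored entries has coords of shape
# (nnz, 2); dim=0 selects the row coordinates of every stored entry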
# gets the shape of a sparse tensor at a given dimension
case ein.GetAttribute(obj, ein.Literal("shape"), dim):
obj = self(obj)
assert isinstance(ftype(obj), SparseTensorFType)
assert isinstance(obj, SparseTensor)
assert dim is not None

# return the shape for the given dimension
return obj.shape[dim]
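# e.g., for a sparse tensor of shape (3, 4), GetAttribute(..., "shape", 1)
# evaluates to 4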

# standard einsum
case ein.Einsum(op, ein.Alias(tns), idxs, arg) if all(isinstance(idx, ein.Index) for idx in idxs):
# This is the main entry point for einsum execution
loops = arg.get_idxs()

assert set(idxs).issubset(loops)
loops = sorted(loops, key=lambda x: x.name)
ctx = EinsumInterpreter(self.xp, self.bindings, loops)
@@ -128,5 +227,47 @@ def __call__(self, node):
axis = [dropped.index(idx) for idx in idxs]
self.bindings[tns] = xp.permute_dims(val, axis)
return (tns,)

# indirect einsum with reduction
case ein.Einsum(op, ein.Alias(tns), idxs, arg):
loops = arg.get_idxs()
true_idxs = node.get_idxs()
assert true_idxs.issubset(loops)

# evaluate the tensor to access/write to
# output tensor must exist with initial reduction values
assert tns in self.bindings
tns_out = self.bindings[tns]
assert len(idxs) == len(tns_out.shape)

# evaluate arg with the full loop context
loops_sorted = sorted(loops, key=lambda x: x.name)
ctx = EinsumInterpreter(self.xp, self.bindings, loops_sorted)
arg_val = ctx(arg)

# reduce over non-true_idxs axes using the reduction op
op_func = self(op)  # the combining op; also needed by the scatter below
reduce_axes = tuple(i for i in range(len(loops_sorted))
if loops_sorted[i] not in true_idxs)
if reduce_axes:
reduce_func = getattr(xp, reduction_ops[op_func])
reduced_val = reduce_func(arg_val, axis=reduce_axes)
else:
reduced_val = arg_val

# evaluate indirect index expressions (same context)
evaled_idxs = [
(xp.arange(tns_out.shape[i]) if isinstance(idx, ein.Index) else ctx(idx))
for i, idx in enumerate(idxs)
]
flat_idx = xp.ravel_multi_index(xp.vstack(evaled_idxs), tns_out.shape)
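# note: ufunc.at applies one update per occurrence, so duplicate flat
# indices accumulate (e.g., xp.add.at with flat_idx [1, 1] adds both
# contributions), unlike plain fancy-index assignment, which keeps only
# the last write per duplicate index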

# scatter/accumulate using the same operation's ufunc; ufunc.at expects
# an array rather than a flatiter, so scatter into a flat view and
# restore the original shape afterwards
scatter_func = getattr(xp, nary_ops[op_func])
flat_out = tns_out.reshape(-1)
scatter_func.at(flat_out, flat_idx, reduced_val)

# assign the output tensor to the bindings
self.bindings[tns] = flat_out.reshape(tns_out.shape)
return (tns,)
case _:
raise ValueError(f"Unknown einsum type: {type(node)}")
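
The interpreter above leans on two NumPy-style building blocks: a flat gather (ravel_multi_index plus take) and an accumulating scatter (ufunc.at). A minimal standalone NumPy sketch of both patterns (illustrative only; the arrays and names below are invented for this note, not part of the diff):

import numpy as np

# gather A[rows[k], cols[k]] for each k -- the ravel_multi_index + take
# pattern used by the indirect-access cases above
A = np.arange(12).reshape(3, 4)
rows = np.array([0, 2, 1])
cols = np.array([3, 0, 1])
flat = np.ravel_multi_index(np.vstack([rows, cols]), A.shape)
gathered = np.take(A, flat)  # == A[rows, cols] -> array([3, 8, 5])

# scatter with accumulation: out[idx[k]] += vals[k], honoring duplicate
# targets -- the ufunc.at pattern used by the indirect einsum case
out = np.zeros(4)
idx = np.array([1, 1, 3])
vals = np.array([10.0, 5.0, 2.0])
np.add.at(out, idx, vals)  # out -> array([ 0., 15.,  0.,  2.])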
48 changes: 47 additions & 1 deletion src/finchlite/finch_einsum/nodes.py
@@ -1,7 +1,7 @@
import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Self, cast
from typing import Any, Optional, Self, cast

from finchlite.algebra import (
overwrite,
@@ -176,6 +176,42 @@ def get_idxs(self) -> set["Index"]:
return idxs


@dataclass(eq=True, frozen=True)
class GetAttribute(EinsumExpr, EinsumTree):
"""
Gets an attribute of a tensor.

Attributes:
obj: The object to get the attribute from.
attr: The name of the attribute to get.
dim: The dimension to get the attribute from.
Note this is an integer index, not a named index.
"""

obj: EinsumExpr
attr: Literal
dim: Optional[int]

@classmethod
def from_children(cls, *children: Term) -> Self:
# expects 3 children: obj, attr, dim
if len(children) != 3:
raise ValueError("GetAttribute expects 3 children (obj + attr + dim)")
obj = cast(EinsumExpr, children[0])
attr = cast(Literal, children[1])
dim = cast(Optional[int], children[2])
return cls(obj, attr, dim)

@property
def children(self):
return [self.obj, self.attr, self.dim]

def get_idxs(self) -> set["Index"]:
idxs = set()
idxs.update(self.obj.get_idxs())
return idxs
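# e.g., a hypothetical GetAttribute(Alias("A"), Literal("coords"), 0)
# would select the dim-0 coordinates of a sparse tensor bound to A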


@dataclass(eq=True, frozen=True)
class Einsum(EinsumTree):
"""
@@ -214,6 +250,12 @@ def from_children(cls, *children: Term) -> Self:
def children(self):
return [self.op, self.tns, self.idxs, self.arg]

def get_idxs(self) -> set["Index"]:
# collect the named indices in the output idxs, including those nested
# inside indirect index expressions; returns a set, as annotated, so
# callers can use set operations like issubset
idxs = set()
for idx in self.idxs:
idxs.update(idx.get_idxs())
return idxs


@dataclass(eq=True, frozen=True)
class Plan(EinsumTree):
@@ -334,6 +376,10 @@ def __call__(self, prgm: EinsumNode):
if len(args) == 1 and fn.val in unary_strs:
return f"{unary_strs[fn.val]}{args_e[0]}"
return f"{self(fn)}({', '.join(args_e)})"
case GetAttribute(obj, attr, dim):
# dim is a plain int (or None), not an EinsumNode, so format it directly
if dim is not None:
return f"{self(obj)}.{self(attr)}[{dim}]"
return f"{self(obj)}.{self(attr)}"
case Einsum(op, tns, idxs, arg):
op_str = infix_strs.get(op.val, op.val.__name__)
self.exec(
4 changes: 4 additions & 0 deletions src/finchlite/tensor/__init__.py
@@ -7,6 +7,8 @@
dense,
element,
)
from .sparse_tensor import SparseTensor, SparseTensorFType


__all__ = [
"DenseLevel",
@@ -17,6 +19,8 @@
"FiberTensorFType",
"Level",
"LevelFType",
"SparseTensor",
"SparseTensorFType",
"dense",
"element",
"tensor",