62 commits
d143ce8
* Added einsum lowerer back;
TheRealMichaelWang Oct 9, 2025
7c51829
Fixed some obvious errors
TheRealMichaelWang Oct 9, 2025
d32aa8a
* Added pytests
TheRealMichaelWang Oct 9, 2025
2524294
* Fixed error in EinsumInterpreter with handling return values from e…
TheRealMichaelWang Oct 10, 2025
f0f2112
* Addressed issue #209 by properly initializing Einsum with return va…
TheRealMichaelWang Oct 10, 2025
6da9ff7
* Addressed issue #209 by removing redudant Produce IR node
TheRealMichaelWang Oct 10, 2025
4969b25
* Added more pytests
TheRealMichaelWang Oct 10, 2025
f01df02
Merge branch 'main' into add-einsum-lowerer
TheRealMichaelWang Oct 10, 2025
b80fb7c
* Fixed ruff errors
TheRealMichaelWang Oct 10, 2025
6ab474f
* Fixed more ruff errors
TheRealMichaelWang Oct 10, 2025
6358d81
* Added support to EinsumPrinterContext to print Plan return values
TheRealMichaelWang Oct 10, 2025
2f6c8f8
* Ran ruff format
TheRealMichaelWang Oct 10, 2025
21213c5
* Fixed mypy type errors
TheRealMichaelWang Oct 10, 2025
5f36539
Still have to invoke tns.to_numpy; added type safe attribute checking…
TheRealMichaelWang Oct 10, 2025
c4e2eda
* Ran ruff
TheRealMichaelWang Oct 10, 2025
9be7e18
* Restored produce einsum ir node
TheRealMichaelWang Oct 14, 2025
4f7446c
* Removed return values from Einsum IR Node
TheRealMichaelWang Oct 18, 2025
a02ffcd
* Fixed issues with einsum printer.
TheRealMichaelWang Oct 18, 2025
c81a1b6
* Fixed type errors in einsum lowerer
TheRealMichaelWang Oct 18, 2025
67b20ea
Merge branch 'main' into add-einsum-lowerer
TheRealMichaelWang Oct 18, 2025
fad5a31
* Fixed ruff errors
TheRealMichaelWang Oct 18, 2025
8e3b5c7
* Added sparse tensor implementation
TheRealMichaelWang Oct 20, 2025
0115285
* Added framework for Insum Lowerer
TheRealMichaelWang Oct 20, 2025
ef39345
* Renamed parameters to binding in einsum lowerer to stay consitent w…
TheRealMichaelWang Oct 20, 2025
aa769d7
Merge branch 'add-einsum-lowerer' of https://github.com/finch-tensor/…
TheRealMichaelWang Oct 20, 2025
ba0992d
* Renamed parameters to bindings in InsumLowerer for consitency with …
TheRealMichaelWang Oct 20, 2025
a372a0d
* Implemented can_optimize method in InsumLowerer
TheRealMichaelWang Oct 20, 2025
d99b571
Impelemented get_sparse_params in InsumLowerer
TheRealMichaelWang Oct 20, 2025
56333d7
* Added GetAttribute Einsum IR node
TheRealMichaelWang Oct 20, 2025
0e2645e
* Implemented to_insum method in InsumLowerer
TheRealMichaelWang Oct 21, 2025
79fda71
* Added top level optimize plan method to insum lowerer
TheRealMichaelWang Oct 21, 2025
d3cf66f
* Added barebones test utilities to insum lowerer pytest.
TheRealMichaelWang Oct 21, 2025
92f59b5
* Added support for GetAttribute EinsumExpr in EinsumPrinterContext
TheRealMichaelWang Oct 21, 2025
e2c9f78
Merge branch 'main' of https://github.com/finch-tensor/finch-tensor-l…
TheRealMichaelWang Oct 21, 2025
f24b43a
* Simplified Einsum Lowerer by inlining lower_to_einsum into compile_…
TheRealMichaelWang Oct 21, 2025
725040e
Refactored EinsumLowerer to streamline aggregate handling and removed…
TheRealMichaelWang Oct 21, 2025
0c26cc5
Enhanced EinsumLowerer to support Reformat cases in MapJoin and Aggre…
TheRealMichaelWang Oct 21, 2025
ade7292
* Ran ruff check and ruff format
TheRealMichaelWang Oct 21, 2025
426e12c
* Moved mapjoin and aggregate compilation into seperate functions to …
TheRealMichaelWang Oct 21, 2025
c94f2d0
* Fixed mypy typing errors
TheRealMichaelWang Oct 21, 2025
2db61ac
* Fixed more mypy errors
TheRealMichaelWang Oct 21, 2025
7f10097
* Finally fixed all mypy errors
TheRealMichaelWang Oct 21, 2025
6aaceee
* Ran ruff
TheRealMichaelWang Oct 21, 2025
b377748
* Undid changes to effectively unchanged files
TheRealMichaelWang Oct 21, 2025
fe29f83
Refactored EinsumLowerer by removing the compile_aggregate method and…
TheRealMichaelWang Oct 22, 2025
89557e2
Refactored EinsumLowerer to consolidate mapjoin handling into a singl…
TheRealMichaelWang Oct 22, 2025
d34b5b3
Refactored EinsumLowerer to streamline Aggregate case handling by con…
TheRealMichaelWang Oct 22, 2025
07c4903
* Ran pytests, ruff, and mypy
TheRealMichaelWang Oct 22, 2025
0ebaa91
suggested changes for clarity, more to come
willow-ahrens Oct 24, 2025
97a815e
Revert "suggested changes for clarity, more to come"
TheRealMichaelWang Oct 27, 2025
523e135
* Removed unecessary imports from logic in einsum lowerer
TheRealMichaelWang Oct 27, 2025
cbf5f43
* Removed unecessary cases, all of which are optimized out
TheRealMichaelWang Oct 27, 2025
87ba773
* Added case for logical reformat AST node in compile operand
TheRealMichaelWang Oct 27, 2025
e0095f0
Inlined compile_expression method
TheRealMichaelWang Oct 27, 2025
4583f9c
* Removed unused parameters from compile operand
TheRealMichaelWang Oct 27, 2025
8ca2eb7
* Fixed potential mypy error in EinsumLowerer compile_operand
TheRealMichaelWang Oct 27, 2025
08dabeb
* Fixed ruff errors
TheRealMichaelWang Oct 27, 2025
140051c
* Removed flatten args
TheRealMichaelWang Oct 27, 2025
42eed92
Refactored compile_plan as a fucntion that only returns an EinsumNode…
TheRealMichaelWang Oct 27, 2025
874c76d
* Fixed potential mypy errors
TheRealMichaelWang Oct 27, 2025
4c4813c
* Fixed ruff errors
TheRealMichaelWang Oct 27, 2025
56fb9ee
Merge branch 'add-einsum-lowerer' of https://github.com/finch-tensor/…
TheRealMichaelWang Oct 28, 2025
4 changes: 4 additions & 0 deletions src/finchlite/autoschedule/__init__.py
@@ -16,6 +16,8 @@
)
from ..symbolic import PostOrderDFS, PostWalk, PreWalk
from .compiler import LogicCompiler
from .einsum import EinsumLowerer
from .insum import InsumLowerer
from .optimize import (
DefaultLogicOptimizer,
concordize,
@@ -43,6 +45,8 @@
"Aggregate",
"Alias",
"DefaultLogicOptimizer",
"EinsumLowerer",
"InsumLowerer",
"Field",
"Literal",
"LogicCompiler",
97 changes: 97 additions & 0 deletions src/finchlite/autoschedule/einsum.py
@@ -0,0 +1,97 @@
from typing import Any, cast

import finchlite.finch_einsum as ein
import finchlite.finch_logic as lgc
from finchlite.algebra import init_value, overwrite


class EinsumLowerer:
def __call__(self, prgm: lgc.Plan) -> tuple[ein.Plan, dict[str, Any]]:
bindings: dict[str, Any] = {}
definitions: dict[str, ein.Einsum] = {}
return cast(ein.Plan, self.compile_plan(prgm, bindings, definitions)), bindings

def compile_plan(
self,
node: lgc.LogicNode,
bindings: dict[str, Any],
definitions: dict[str, ein.Einsum],
) -> ein.EinsumNode | None:
match node:
case lgc.Plan(bodies):
ein_bodies = [
self.compile_plan(body, bindings, definitions) for body in bodies
]
not_none_bodies = [body for body in ein_bodies if body is not None]
return ein.Plan(tuple(not_none_bodies))
case lgc.Query(lgc.Alias(name), lgc.Table(lgc.Literal(val), _)):
bindings[name] = val
return None
case lgc.Query(
lgc.Alias(name),
lgc.Aggregate(lgc.Literal(operation), lgc.Literal(init), arg, _),
):
einidxs = tuple(ein.Index(field.name) for field in node.rhs.fields)
my_bodies = []
if init != init_value(operation, type(init)):
my_bodies.append(
ein.Einsum(
op=ein.Literal(overwrite),
tns=ein.Alias(name),
idxs=einidxs,
arg=ein.Literal(init),
)
)
my_bodies.append(
ein.Einsum(
op=ein.Literal(operation),
tns=ein.Alias(name),
idxs=einidxs,
arg=self.compile_operand(arg),
)
)
return ein.Plan(tuple(my_bodies))
case lgc.Query(lgc.Alias(name), rhs):
einarg = self.compile_operand(rhs)
return ein.Einsum(
op=ein.Literal(overwrite),
tns=ein.Alias(name),
idxs=tuple(ein.Index(field.name) for field in node.rhs.fields),
arg=einarg,
)

case lgc.Produces(args):
return_values = []
for ret_arg in args:
if not isinstance(ret_arg, lgc.Alias):
raise Exception(f"Unrecognized logic: {ret_arg}")
return_values.append(ein.Alias(ret_arg.name))

return ein.Produces(tuple(return_values))
case _:
raise Exception(f"Unrecognized logic: {node}")

# lowers nested mapjoin logic IR nodes into a single pointwise expression
def compile_operand(
self,
ex: lgc.LogicNode,
) -> ein.EinsumExpr:
match ex:
case lgc.Reformat(_, rhs):
return self.compile_operand(rhs)
case lgc.Reorder(arg, _):
return self.compile_operand(arg)
case lgc.MapJoin(lgc.Literal(operation), lgcargs):
args = tuple(self.compile_operand(arg) for arg in lgcargs)
return ein.Call(ein.Literal(operation), args)
case lgc.Relabel(
lgc.Alias(name), idxs
): # relabel is really just a glorified pointwise access
return ein.Access(
tns=ein.Alias(name),
idxs=tuple(ein.Index(idx.name) for idx in idxs),
)
case lgc.Literal(value):
return ein.Literal(val=value)
case _:
raise Exception(f"Unrecognized logic: {ex}")
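The Aggregate case in `compile_plan` splits a reduction whose initial value differs from the operator's default into two einsum statements: an `overwrite` that seeds the output, then an accumulating einsum over the pointwise expression. A minimal plain-Python sketch of those semantics (not the finchlite API; the shapes and numbers are made up for illustration):

```python
# Aggregate(add, init=10, A[i,k] * B[k,j], reduce over k) lowers to:
#   statement 1:  C[i, j] = 10                  (op = overwrite; only emitted
#                                                because 10 != init_value(add) == 0)
#   statement 2:  C[i, j] += A[i, k] * B[k, j]  (op = add)
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
n = 2
init = 10  # non-default initial value for add

# statement 1: overwrite every output element with the init value
C = [[init for _ in range(n)] for _ in range(n)]

# statement 2: accumulate the pointwise expression over the reduced index k
for i in range(n):
    for j in range(n):
        for k in range(n):
            C[i][j] += A[i][k] * B[k][j]

assert C == [[29, 32], [53, 60]]  # 10 + (A @ B) elementwise
```

If the init value equals `init_value(operation, type(init))`, the first statement is skipped and only the accumulating einsum is emitted.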
232 changes: 232 additions & 0 deletions src/finchlite/autoschedule/insum.py
@@ -0,0 +1,232 @@
import operator
from typing import Any, cast

import finchlite.finch_einsum as ein
import finchlite.finch_logic as logic
from finchlite.algebra import ifelse, init_value, overwrite
from finchlite.autoschedule import EinsumLowerer
from finchlite.symbolic import PostWalk, Rewrite, ftype, gensym
from finchlite.tensor import SparseTensorFType

class InsumLowerer:
def __init__(self):
self.el = EinsumLowerer()

def can_optimize(self, en: ein.EinsumNode, sparse: set[str]) -> tuple[bool, dict[str, tuple[ein.Index, ...]]]:
"""
Checks if an einsum node can be optimized via indirect einsums.
Specifically, it checks whether the node is an einsum that references any sparse tensor binding/parameter.

Arguments:
en: The einsum node to check.
sparse: The set of aliases of sparse tensor bindings/parameters.

Returns:
A tuple containing:
- A boolean indicating if the einsum node can be optimized.
- A dictionary mapping sparse binding aliases to the indices they are referenced with.
"""
if not isinstance(en, ein.Einsum):
return False, {}

einsum = cast(ein.Einsum, en)

refed_sparse: dict[str, tuple[ein.Index, ...]] = {}

def sparse_detect(node: ein.EinsumExpr):
nonlocal refed_sparse

match node:
case ein.Access(ein.Alias(name), idxs):
if name not in sparse:
return None

if name in refed_sparse and refed_sparse[name] != idxs:
raise ValueError(
f"Sparse binding {name} is being referenced "
"with different indices.")
refed_sparse[name] = idxs
return None

PostWalk(sparse_detect)(einsum.arg)
return len(refed_sparse) > 0, refed_sparse

def to_insum(self, einsum: ein.Einsum, sparse: str, sparse_idxs: tuple[ein.Index, ...]) -> list[ein.EinsumNode]:
bodies: list[ein.EinsumNode] = []
reduced_idx = ein.Index(gensym("pos"))
# initialize mask tensor T, which counts the stored (non-zero) elements in each
# reduced fiber of the sparse tensor
# Essentially T[idxs...] = the number of stored elements in the fiber being
# reduced at idxs..., so the fiber contains implicit zeros iff T[idxs...] is
# less than the fiber's size
T_idxs = tuple(idx for idx in einsum.idxs if idx in sparse_idxs)
T_mask = ein.Alias(gensym(f"{sparse}_T"))
bodies.append(ein.Einsum( # initialize every element of T to 0
op=ein.Literal(overwrite),
alias=T_mask,
idxs=T_idxs,
arg=ein.Literal(0)
))
bodies.append(ein.Einsum(
op=ein.Literal(operator.add),
alias=T_mask,
idxs=(
ein.Access(
ein.GetAttribute(
obj=ein.Alias(sparse),
attr=ein.Literal("coords"),
idx=None
),
(reduced_idx,)
),
),
arg=ein.Literal(1)
))

# get the reduced indices in the sparse tensor
reduced_idxs = tuple(idx for idx in einsum.idxs if idx not in sparse_idxs)

# get the size of the fiber in the sparse tensor being reduced
reduced_fiber_size = ein.Call(ein.Literal(operator.mul), (
ein.Literal(1),
*[ein.GetAttribute(
obj=ein.Alias(sparse),
attr=ein.Literal("shape"),
idx=idx
) for idx in reduced_idxs]
))

# rewrite the indices used to iterate over the sparse tensor
def rewrite_indicies(idxs: tuple[ein.EinsumExpr, ...]) -> tuple[ein.EinsumExpr, ...]:
if idxs == sparse_idxs:
return (ein.Access(
ein.GetAttribute(
obj=ein.Alias(sparse),
attr=ein.Literal("coords"),
idx=None
),
(reduced_idx,)
),)

new_idxs = []
for idx in idxs:
match idx:
case ein.Index(_) if idx in sparse_idxs:
new_idxs.append(ein.Access(
ein.GetAttribute(
obj=ein.Alias(sparse),
attr=ein.Literal("coords"),
idx=idx
),
(reduced_idx,)
))
case _:
new_idxs.append(idx)
return tuple(new_idxs)

# pattern matching rule to rewrite all indices in arg
def rewrite_all_indicies(node: ein.EinsumExpr) -> ein.EinsumExpr:
match node:
case ein.Access(ein.Alias(name), idxs) if name == sparse and idxs == sparse_idxs:
return ein.Access(
ein.GetAttribute(
obj=ein.Alias(sparse),
attr=ein.Literal("elems"),
idx=None
),
(reduced_idx,)
)
case ein.Access(ein.Alias(name), idxs):
return ein.Access(ein.Alias(name), rewrite_indicies(idxs))

# rewrite a pointwise expression to assume that the sparse tensor is all-zero
def rewrite_zero(node: ein.EinsumExpr) -> ein.EinsumExpr:
match node:
case ein.Access(ein.Alias(name), _) if name == sparse:
return ein.Literal(0)
case ein.Access(ein.GetAttribute(ein.Alias(name), ein.Literal("elems"), None), _) if name == sparse:
return ein.Literal(0)

# rewrite the argument twice: once to index through the stored coordinates,
# and once assuming every sparse element is zero
new_einarg = Rewrite(PostWalk(rewrite_all_indicies))(einsum.arg)
zero_einarg = Rewrite(PostWalk(rewrite_zero))(einsum.arg)

# initialize the reduction values
# Essentially, we seed the output with the contribution of the implicit zeros
# in each reduced fiber, since their iterations are skipped below.
# Core assumption: the reduction operator $f$ is associative and commutative,
# i.e. $f(a, f(b, c)) = f(f(a, b), c)$ for all $a, b, c$, so folding a single
# zero element into the initial value has the same effect as folding one or
# more zero elements at any point in the reduced fiber.
init = 0 if einsum.op == ein.Literal(overwrite) else init_value(einsum.op.val, type(0))
bodies.append(ein.Einsum(
op=ein.Literal(overwrite),
alias=einsum.alias,
idxs=einsum.idxs,
arg=ein.Call(ein.Literal(ifelse), (
ein.Call(ein.Literal(operator.eq), ( # check if T[idxs...] == reduced_fiber_size
ein.Access(T_mask, (reduced_idx,)),
reduced_fiber_size
)),
ein.Literal(init), # if the fiber is fully stored, the seed is just the default init
ein.Call(einsum.op, ( # einsum.op is already an ein.Literal; don't re-wrap it
ein.Literal(init),
zero_einarg
))
))
))

# finally, execute the rewritten einsum over the stored coordinates (the insum)
bodies.append(ein.Einsum(
op=einsum.op,
alias=einsum.alias,
idxs=rewrite_indicies(einsum.idxs),
arg=new_einarg
))

return bodies

def get_sparse_params(self, bindings: dict[str, Any]) -> set[str]:
"""
Gets the set of sparse binding aliases from the bindings dictionary.

Arguments:
bindings: The bindings dictionary.

Returns:
A set of sparse binding aliases.
"""

sparse = set()

for alias, value in bindings.items():
# EinsumLowerer stores the raw tensor value for each alias (not a
# logic.Table node), so check its ftype directly
if isinstance(ftype(value), SparseTensorFType):
sparse.add(alias)

return sparse

def optimize_plan(self, plan: ein.Plan, bindings: dict[str, Any]) -> tuple[ein.Plan, dict[str, Any]]:
sparse = self.get_sparse_params(bindings)

new_bodies = []
for body in plan.bodies:
can_optimize, all_sparse = self.can_optimize(body, sparse)
if can_optimize:
sparse_binding, sparse_idxs = next(iter(all_sparse.items()))
new_bodies.extend(self.to_insum(body, sparse_binding, sparse_idxs))
else:
new_bodies.append(body)

return ein.Plan(tuple(new_bodies)), bindings
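The mask-tensor transformation that `to_insum` builds in IR form can be illustrated with a plain-Python COO sketch. This is a hedged illustration, not the finchlite API: the names `coords` and `elems` mirror the `GetAttribute` attributes used above, and the data is made up.

```python
import operator

# a 3x4 sparse matrix in COO form; we reduce over columns (j), keeping rows (i)
shape = (3, 4)
coords = [(0, 1), (0, 3), (2, 2)]  # stored (i, j) positions
elems = [5.0, 7.0, -2.0]           # stored values
f, init = operator.add, 0.0
fiber_size = shape[1]              # number of j positions in each row fiber

# mask tensor T: T[i] counts the stored elements in row i
T = [0] * shape[0]
for (i, _j) in coords:
    T[i] += 1

# seed each output with the reduction over the implicit zeros: a fully
# stored fiber starts from `init`; otherwise fold one representative zero
# into `init` (sound because f is associative and commutative)
out = [init if T[i] == fiber_size else f(init, 0.0) for i in range(shape[0])]

# finally, accumulate only the stored elements (the "insum")
for (i, _j), v in zip(coords, elems):
    out[i] = f(out[i], v)

assert T == [2, 0, 1]
assert out == [12.0, 0.0, -2.0]
```

The payoff is that the accumulation loop runs over the stored coordinates only, instead of all `shape[0] * shape[1]` positions, while the mask and seed statements preserve the contribution of the skipped zeros for operators where `f(x, 0) != x`.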
