From a534fc8bbbb5862354e433184c948c8af0b678d2 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 8 Sep 2025 09:03:24 -0400 Subject: [PATCH 01/26] * Added support for printing lowered logic IR to better understand IR * Added EinsumExtractor barebones and Einsum and Einprod --- src/finchlite/autoschedule/__init__.py | 10 +++ src/finchlite/autoschedule/einsum.py | 95 ++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 src/finchlite/autoschedule/einsum.py diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py index 32022d8e..b8530d07 100644 --- a/src/finchlite/autoschedule/__init__.py +++ b/src/finchlite/autoschedule/__init__.py @@ -38,11 +38,20 @@ push_fields, set_loop_order, ) +from .einsum import ( + Einsum, + Einprod, + EinsumTransformer, + PrintingLogicOptimizer +) __all__ = [ "Aggregate", "Alias", "DefaultLogicOptimizer", + "Einsum", + "Einprod", + "EinsumTransformer", "Field", "Literal", "LogicCompiler", @@ -51,6 +60,7 @@ "PostOrderDFS", "PostWalk", "PreWalk", + "PrintingLogicOptimizer", "Produces", "Query", "Reformat", diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py new file mode 100644 index 00000000..820d24cb --- /dev/null +++ b/src/finchlite/autoschedule/einsum.py @@ -0,0 +1,95 @@ +from finchlite.finch_logic import LogicTree, LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal +from finchlite.autoschedule import optimize, DefaultLogicOptimizer, LogicCompiler +from finchlite.symbolic import PostOrderDFS + +class Einsum(LogicTree, LogicExpression): + pass + +class Einprod(LogicTree, LogicExpression): + pass + +class EinsumTransformer(DefaultLogicOptimizer): + """Transforms program into Einsum and Einprod""" + + def __init__(self, ctx: LogicCompiler, verbose=True): + super().__init__(ctx) + self.verbose = verbose + + def __call__(self, prgm: LogicNode): + # First optimize the program + prgm = optimize(prgm) + + return self.ctx(prgm) + + def transform(self, prgm: LogicNode): + pass + + +class PrintingLogicOptimizer(DefaultLogicOptimizer): + """Custom optimizer that prints MapJoin and Aggregate operations""" + + def __init__(self, ctx: LogicCompiler, verbose=True): + super().__init__(ctx) + self.verbose = verbose + self.operation_count = {"MapJoin": 0, "Aggregate": 0} + + def __call__(self, prgm: LogicNode): + # First optimize the program + prgm = optimize(prgm) + + # Then traverse and print all MapJoin/Aggregate operations + if self.verbose: + print("\n=== Finch Logic IR Operations ===") + self._print_operations(prgm) + print(f"\nTotal MapJoins: {self.operation_count['MapJoin']}") + print(f"Total Aggregates: {self.operation_count['Aggregate']}") + print("================================\n") + + # Continue with compilation + return self.ctx(prgm) + + def _print_operations(self, node): + """Traverse the Logic IR and print MapJoin/Aggregate operations""" + for n in PostOrderDFS(node): + match n: + case MapJoin(op, args): + self.operation_count["MapJoin"] += 1 + print(f"\nMapJoin #{self.operation_count['MapJoin']}:") + print(f" Operation: {self._format_op(op)}") + print(f" Args: {self._format_args(args)}") + print(f" Fields: {n.fields}") + + case Aggregate(op, init, arg, idxs): + self.operation_count["Aggregate"] += 1 + print(f"\nAggregate #{self.operation_count['Aggregate']}:") + print(f" Operation: {self._format_op(op)}") + print(f" Init: {self._format_literal(init)}") + print(f" Reduce dims: {idxs}") + print(f" Input fields: {arg.fields if hasattr(arg, 
'fields') else 'N/A'}") + print(f" Output fields: {n.fields}") + + def _format_op(self, op): + """Format operation for printing""" + if isinstance(op, Literal): + if hasattr(op.val, '__name__'): + return op.val.__name__ + return str(op.val) + return str(op) + + def _format_literal(self, lit): + """Format literal for printing""" + if isinstance(lit, Literal): + return str(lit.val) + return str(lit) + + def _format_args(self, args): + """Format arguments for printing""" + formatted = [] + for arg in args: + if isinstance(arg, Alias): + formatted.append(f"Alias({arg.name})") + elif isinstance(arg, Table): + formatted.append(f"Table(...)") + else: + formatted.append(type(arg).__name__) + return formatted \ No newline at end of file From e98719a859de35872acbe1ca314a62331f6f6df7 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 8 Sep 2025 10:36:33 -0400 Subject: [PATCH 02/26] - Fully implemented Einsum logic node - Implemented EInsumTransformer transofrm which transforms aggregates and map joins into einsums - Not comprehensive --- src/finchlite/autoschedule/einsum.py | 74 ++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 820d24cb..43213276 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,28 +1,82 @@ -from finchlite.finch_logic import LogicTree, LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal -from finchlite.autoschedule import optimize, DefaultLogicOptimizer, LogicCompiler -from finchlite.symbolic import PostOrderDFS +from dataclasses import dataclass +import operator +from finchlite.finch_logic import LogicTree, LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal, Field +from finchlite.autoschedule import DefaultLogicOptimizer, LogicCompiler, optimize +from finchlite.symbolic import Rewrite, PostWalk, PostOrderDFS + +@dataclass(eq=True, frozen=False) class Einsum(LogicTree, LogicExpression): - pass + """ + NumPy-style einsum logic node. + + - inputs: per-argument axis labels as Fields, e.g., ((i,k), (k,j)) + - output: output axis labels as Fields, e.g., (i,j) + - args: input expressions + """ + + inputs: tuple[tuple[Field, ...], ...] + outputs: tuple[Field, ...] + args: tuple[LogicExpression, ...] + + def __init__(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None): + self.args = args + + self.inputs = inputs + if inputs is None: # If inputs are not provided, compute them from the arguments + self.inputs = tuple(tuple(f for f in arg.fields) for arg in args) + + union_fields: list[Field] = outputs if outputs is not None else inputs + for labels in self.inputs: + for f in labels: + if f not in union_fields: + union_fields.append(f) + self.outputs = tuple(union_fields) + + @property + def children(self): + # Treat only args as children in the term tree + return list(self.args) -class Einprod(LogicTree, LogicExpression): - pass + @property + def fields(self) -> list[Field]: + return list(self.outputs) class EinsumTransformer(DefaultLogicOptimizer): - """Transforms program into Einsum and Einprod""" + """ + Rewrite unoptimized Logic IR (mostly MapJoin and Aggregate) into Einsum nodes. 
From 47aedc88b896aabde95d29fb517babfa9fe72c0c Mon Sep 17 00:00:00 2001
From: TheRealMichaelWang
Date: Mon, 8 Sep 2025 10:47:53 -0400
Subject: [PATCH 03/26] * Simplified union calculation for output fields

---
 src/finchlite/autoschedule/einsum.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index 43213276..d17afa98 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -22,16 +22,16 @@ class Einsum(LogicTree, LogicExpression):
     def __init__(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None):
         self.args = args
 
-        self.inputs = inputs
-        if inputs is None: # If inputs are not provided, compute them from the arguments
-            self.inputs = tuple(tuple(f for f in arg.fields) for arg in args)
+        #inputs are the fields of the arguments by default
+        self.inputs = inputs if inputs is not None else tuple(tuple(f for f in arg.fields) for arg in args)
 
-        union_fields: list[Field] = outputs if outputs is not None else inputs
+        #outputs are the union of the inputs by default, or the union of the outputs if provided
+        union_fields: list[Field] = outputs if outputs is not None else inputs #union fields are inputs by default
         for labels in self.inputs:
             for f in labels:
                 if f not in union_fields:
                     union_fields.append(f)
-        self.outputs = tuple(union_fields)
+        self.outputs = tuple(union_fields) #outputs are simply the union of the union fields

From 5905b3b06f5f5cb4843ebe899c61b644d9b0deb6 Mon Sep 17 00:00:00 2001
From: TheRealMichaelWang
Date: Mon, 8 Sep 2025 10:57:06 -0400
Subject: [PATCH 04/26] * Added to_string support to the Einsum node

---
 src/finchlite/autoschedule/einsum.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index d17afa98..83332268 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -42,6 +42,9 @@ class Einsum(LogicTree, LogicExpression):
     def fields(self) -> list[Field]:
         return list(self.outputs)
 
+    def to_string(self) -> str:
+        return f"np.einsum(\"{','.join(''.join(str(f) for f in labels) for labels in self.inputs)}->{''.join(str(f) for f in self.outputs)}\", {', '.join(str(a) for a in self.args)})"
+
 class EinsumTransformer(DefaultLogicOptimizer):
     """
     Rewrite unoptimized Logic IR (mostly MapJoin and Aggregate) into Einsum nodes. 
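
To make patch 04 concrete: the printer is meant to produce NumPy einsum notation, concatenating each argument's index labels and joining the groups with commas before the '->'. A rough expectation for the matmul arguments from the earlier sketch, assuming str(Field) yields the bare index name and str of each argument yields its alias:

    ein = Einsum(args=(a, b), inputs=None, outputs=None)
    print(ein.to_string())
    # hypothetically: np.einsum("ik,kj->ikj", A, B)

Note the output labels: because __init__ unions every input field into outputs, no index is ever dropped, so a true contraction like ik,kj->ij cannot yet be expressed here; later patches rework this by tracking reduced indices explicitly.
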
From fb0c097c5c7e2cef983353bbfd6237bfb7188ad6 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 8 Sep 2025 11:43:59 -0400 Subject: [PATCH 05/26] - Added more patterns to einsum transformer --- src/finchlite/autoschedule/einsum.py | 59 ++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 83332268..04d6a6d4 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,7 +1,7 @@ from dataclasses import dataclass import operator -from finchlite.finch_logic import LogicTree, LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal, Field +from finchlite.finch_logic import LogicTree, LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal, Field, Relabel, Reorder from finchlite.autoschedule import DefaultLogicOptimizer, LogicCompiler, optimize from finchlite.symbolic import Rewrite, PostWalk, PostOrderDFS @@ -15,12 +15,15 @@ class Einsum(LogicTree, LogicExpression): - args: input expressions """ + isEinProduct: bool + inputs: tuple[tuple[Field, ...], ...] outputs: tuple[Field, ...] args: tuple[LogicExpression, ...] - def __init__(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None): + def __init__(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None, isEinProduct: bool = False): self.args = args + self.isEinProduct = isEinProduct #not an einsum but an ein-product #inputs are the fields of the arguments by default self.inputs = inputs if inputs is not None else tuple(tuple(f for f in arg.fields) for arg in args) @@ -50,10 +53,16 @@ class EinsumTransformer(DefaultLogicOptimizer): Rewrite unoptimized Logic IR (mostly MapJoin and Aggregate) into Einsum nodes. Pattern handled: - - Aggregate(add, 0, MapJoin(mul, args), reduce_idxs) -> Einsum(inputs, output, args) - - MapJoin(mul, args) -> Einsum(inputs, output, args) + - Aggregate(add, 0, MapJoin(mul, args), reduce_idxs) -> Einsum(args, outputs=reduce_idxs) + - Aggregate(add, 0, Relabel(MapJoin(mul, args), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) + - Aggregate(add, 0, Reorder(MapJoin(mul, args), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) + - Aggregate(add, 0, Relabel(Reorder(MapJoin(mul, args), _), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) + - Aggregate(add, 0, Reorder(Relabel(MapJoin(mul, args), _), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) + - Aggregate(add, 0, Einsum(...), reduce_idxs)-> Einsum(..., outputs=reduce_idxs) + + - MapJoin(mul, args)-> Einsum(args) #elementwise product (no contraction) + - Aggregate(mul, 0, MapJoin(mul, args), reduce_idxs)-> Einsum(args, outputs=reduce_idxs, isEinProduct=True) #elementwise product (no contraction) - After rewriting, Einsum nodes are lowered back to MapJoin/Aggregate for compilation. 
""" def __init__(self, ctx: LogicCompiler, verbose=True): @@ -69,15 +78,41 @@ def __call__(self, prgm: LogicNode): def transform(self, prgm: LogicNode): def rule(node): match node: - # Sum over product -> Einsum - case Aggregate(Literal(op_add), Literal(init), MapJoin(Literal(op_mul), args), idxs): - if op_add is operator.add and init == 0: - return Einsum(args=args, inputs=None, outputs=idxs) + # Sum over product with harmless wrappers -> Einsum + case Aggregate(Literal(operator.add), Literal(0), Relabel(MapJoin(Literal(operator.mul), args), _), idxs): + return Einsum(args=args, inputs=None, outputs=idxs) + case Aggregate(Literal(operator.add), Literal(0), Reorder(MapJoin(Literal(operator.mul), args), _), idxs): + return Einsum(args=args, inputs=None, outputs=idxs) + case Aggregate( + Literal(operator.add), + Literal(0), + Relabel(Reorder(MapJoin(Literal(operator.mul), args), _), _), + idxs, + ): + return Einsum(args=args, inputs=None, outputs=idxs) + case Aggregate( + Literal(operator.add), + Literal(0), + Reorder(Relabel(MapJoin(Literal(operator.mul), args), _), _), + idxs, + ): + return Einsum(args=args, inputs=None, outputs=idxs) + + # Sum over already-converted Einsum (e.g., MapJoin->Einsum happened earlier) + case Aggregate(Literal(operator.add), Literal(0), Einsum(args=args, inputs=_, outputs=_), idxs): + return Einsum(args=args, inputs=None, outputs=idxs) + + # Original core pattern matching rules + # Sum over product -> Einsum(no contraction) + case Aggregate(Literal(operator.add), Literal(0), MapJoin(Literal(operator.mul), args), idxs): + return Einsum(args=args, inputs=None, outputs=idxs) + # Sum over product -> Einsum (no contraction) + case Aggregate(Literal(operator.mul), Literal(0), MapJoin(Literal(operator.mul), args), idxs): + return Einsum(args=args, inputs=None, outputs=idxs, isEinProduct=True) # Pure elementwise product -> Einsum (no contraction) - case MapJoin(Literal(op_mul), args): - if op_mul is operator.mul: - return Einsum(args=args, inputs=None, output=None) + case MapJoin(Literal(operator.mul), args): + return Einsum(args=args, inputs=None, outputs=None) return Rewrite(PostWalk(rule))(prgm) From a6470a66daa41baf2260438d9a10b5abbfc8cd27 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 11 Sep 2025 12:28:36 -0400 Subject: [PATCH 06/26] * Removed einprod * Added barebones pointwise AST --- src/finchlite/autoschedule/__init__.py | 2 - src/finchlite/autoschedule/einsum.py | 131 +++++++++++++++++++------ 2 files changed, 101 insertions(+), 32 deletions(-) diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py index b8530d07..5e692653 100644 --- a/src/finchlite/autoschedule/__init__.py +++ b/src/finchlite/autoschedule/__init__.py @@ -40,7 +40,6 @@ ) from .einsum import ( Einsum, - Einprod, EinsumTransformer, PrintingLogicOptimizer ) @@ -50,7 +49,6 @@ "Alias", "DefaultLogicOptimizer", "Einsum", - "Einprod", "EinsumTransformer", "Field", "Literal", diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 04d6a6d4..d1bb4d69 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,12 +1,86 @@ from dataclasses import dataclass +from abc import ABC +from typing import Callable, Self import operator -from finchlite.finch_logic import LogicTree, LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal, Field, Relabel, Reorder +from finchlite.finch_logic import LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal, Field, 
Relabel, Reorder from finchlite.autoschedule import DefaultLogicOptimizer, LogicCompiler, optimize -from finchlite.symbolic import Rewrite, PostWalk, PostOrderDFS +from finchlite.symbolic import Rewrite, PostWalk, PostOrderDFS, Term, TermTree -@dataclass(eq=True, frozen=False) -class Einsum(LogicTree, LogicExpression): + +@dataclass(eq=True, frozen=True) +class PointwiseNode(Term, ABC): + """ + PointwiseNode + + Represents an AST node in the Einsum Pointwise Expression IR + """ + + @classmethod + def head(cls): + """Returns the head of the node.""" + return cls + + @classmethod + def make_term(cls, head, *children: Term) -> Self: + return head.from_children(*children) + + @classmethod + def from_children(cls, *children: Term) -> Self: + return cls(*children) + +@dataclass(eq=True, frozen=True) +class PointwiseAccess(PointwiseNode, TermTree): + """Tensor access like a[i, j].""" + tensor: LogicExpression + idxs: tuple[Field, ...] # (Field('i'), Field('j')) + # Children: None (leaf) + + @classmethod + def from_children(cls, tensor: LogicExpression, idxs: tuple[Field, ...]) -> Self: + return cls(tensor, idxs) + + @property + def children(self): + return [self.tensor, *self.idxs] + +@dataclass(eq=True, frozen=True) +class PointwiseOp(PointwiseNode): + """Operation like + or *.""" + op: Callable #the function to apply e.g., operator.add + args: tuple[PointwiseNode, ...] # Subtrees + # Children: The args + + @classmethod + def from_children(cls, op: Callable, args: tuple[PointwiseNode, ...]) -> Self: + return cls(op, args) + + @property + def children(self): + return [self.op, *self.args] + +@dataclass(eq=True, frozen=True) +class PointwiseLiteral(PointwiseNode): + """ + PointwiseLiteral + + A scalar literal/value for pointwise operations. + """ + + val: float + + @classmethod + def from_children(cls, val: float) -> Self: + return cls(val) + + @property + def children(self): + return [self.val] + +#einsum and einsum ast not part of logic IR +#transform to it's own language +@dataclass(eq=True, frozen=True) +class Einsum: """ NumPy-style einsum logic node. @@ -21,21 +95,6 @@ class Einsum(LogicTree, LogicExpression): outputs: tuple[Field, ...] args: tuple[LogicExpression, ...] - def __init__(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None, isEinProduct: bool = False): - self.args = args - self.isEinProduct = isEinProduct #not an einsum but an ein-product - - #inputs are the fields of the arguments by default - self.inputs = inputs if inputs is not None else tuple(tuple(f for f in arg.fields) for arg in args) - - #outputs are the union of the inputs by default, or the union of the outputs if provided - union_fields: list[Field] = outputs if outputs is not None else inputs #union fields are inputs by default - for labels in self.inputs: - for f in labels: - if f not in union_fields: - union_fields.append(f) - self.outputs = tuple(union_fields) #outputs are simply the union of the union fields - @property def children(self): # Treat only args as children in the term tree @@ -45,9 +104,6 @@ def children(self): def fields(self) -> list[Field]: return list(self.outputs) - def to_string(self) -> str: - return f"np.einsum(\"{','.join(self.inputs)}->{','.join(self.outputs)}\", {','.join(self.args)})" - class EinsumTransformer(DefaultLogicOptimizer): """ Rewrite unoptimized Logic IR (mostly MapJoin and Aggregate) into Einsum nodes. 
@@ -71,48 +127,63 @@ def __init__(self, ctx: LogicCompiler, verbose=True): def __call__(self, prgm: LogicNode): prgm = optimize(prgm) + prgm = self.ctx(prgm) + transformed = self.transform(prgm) return transformed + def make_einsum(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None, isEinProduct: bool = False): + #inputs are the fields of the arguments by default + construct_inputs = inputs if inputs is not None else tuple(tuple(f for f in arg.fields) for arg in args) + + #outputs are the union of the inputs by default, or the union of the outputs if provided + union_fields: list[Field] = list(outputs) if outputs is not None else list() #union fields are inputs by default + for labels in construct_inputs: + for f in labels: + if f not in union_fields: + union_fields.append(f) + construct_outputs = tuple(union_fields) + return Einsum(args=args, inputs=construct_inputs, outputs=construct_outputs, isEinProduct=isEinProduct) + def transform(self, prgm: LogicNode): def rule(node): match node: # Sum over product with harmless wrappers -> Einsum case Aggregate(Literal(operator.add), Literal(0), Relabel(MapJoin(Literal(operator.mul), args), _), idxs): - return Einsum(args=args, inputs=None, outputs=idxs) + return self.make_einsum(args=args, inputs=None, outputs=idxs) case Aggregate(Literal(operator.add), Literal(0), Reorder(MapJoin(Literal(operator.mul), args), _), idxs): - return Einsum(args=args, inputs=None, outputs=idxs) + return self.make_einsum(args=args, inputs=None, outputs=idxs) case Aggregate( Literal(operator.add), Literal(0), Relabel(Reorder(MapJoin(Literal(operator.mul), args), _), _), idxs, ): - return Einsum(args=args, inputs=None, outputs=idxs) + return self.make_einsum(args=args, inputs=None, outputs=idxs) case Aggregate( Literal(operator.add), Literal(0), Reorder(Relabel(MapJoin(Literal(operator.mul), args), _), _), idxs, ): - return Einsum(args=args, inputs=None, outputs=idxs) + return self.make_einsum(args=args, inputs=None, outputs=idxs) # Sum over already-converted Einsum (e.g., MapJoin->Einsum happened earlier) case Aggregate(Literal(operator.add), Literal(0), Einsum(args=args, inputs=_, outputs=_), idxs): - return Einsum(args=args, inputs=None, outputs=idxs) + return self.make_einsum(args=args, inputs=None, outputs=idxs) # Original core pattern matching rules # Sum over product -> Einsum(no contraction) case Aggregate(Literal(operator.add), Literal(0), MapJoin(Literal(operator.mul), args), idxs): - return Einsum(args=args, inputs=None, outputs=idxs) + return self.make_einsum(args=args, inputs=None, outputs=idxs) # Sum over product -> Einsum (no contraction) case Aggregate(Literal(operator.mul), Literal(0), MapJoin(Literal(operator.mul), args), idxs): - return Einsum(args=args, inputs=None, outputs=idxs, isEinProduct=True) + return self.make_einsum(args=args, inputs=None, outputs=idxs, isEinProduct=True) # Pure elementwise product -> Einsum (no contraction) case MapJoin(Literal(operator.mul), args): - return Einsum(args=args, inputs=None, outputs=None) + return self.make_einsum(args=args, inputs=None, outputs=None) return Rewrite(PostWalk(rule))(prgm) From dd30bcc6046cbc1487fdc571454f8c8845fe077f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 11 Sep 2025 12:42:29 -0400 Subject: [PATCH 07/26] * Added comments to Pointwise IR * Updated Einsum to fit PointwiseIR --- src/finchlite/autoschedule/einsum.py | 59 +++++++++++++++++++--------- 1 file changed, 41 insertions(+), 18 
deletions(-)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index d1bb4d69..0aeb5c6f 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -1,3 +1,4 @@
+from ast import Call
 from dataclasses import dataclass
 from abc import ABC
 from typing import Callable, Self
@@ -31,7 +32,16 @@ def from_children(cls, *children: Term) -> Self:
 
 @dataclass(eq=True, frozen=True)
 class PointwiseAccess(PointwiseNode, TermTree):
-    """Tensor access like a[i, j]."""
+    """
+    PointwiseAccess
+
+    Tensor access like a[i, j].
+
+    Attributes:
+        tensor: The tensor to access.
+        idxs: The indices at which to access the tensor.
+    """
+
     tensor: LogicExpression
     idxs: tuple[Field, ...] # (Field('i'), Field('j'))
     # Children: None (leaf)
@@ -46,7 +56,17 @@ def children(self):
 
 @dataclass(eq=True, frozen=True)
 class PointwiseOp(PointwiseNode):
-    """Operation like + or *."""
+    """
+    PointwiseOp
+
+    Represents an operation like + or * applied pointwise to multiple operands.
+    If the operation is not commutative, the node must be binary, taking exactly two args.
+
+    Attributes:
+        op: The function to apply e.g., operator.add
+        args: The arguments to the operation.
+    """
+
     op: Callable #the function to apply e.g., operator.add
     args: tuple[PointwiseNode, ...] # Subtrees
     # Children: The args
@@ -80,29 +100,32 @@ def children(self):
 
 #einsum and einsum ast not part of logic IR
 #transform to it's own language
 @dataclass(eq=True, frozen=True)
-class Einsum:
+class Einsum(TermTree):
     """
-    NumPy-style einsum logic node.
+    Einsum
 
-    - inputs: per-argument axis labels as Fields, e.g., ((i,k), (k,j))
-    - output: output axis labels as Fields, e.g., (i,j)
-    - args: input expressions
-    """
+    An einsum operation that maps pointwise expressions and aggregates them.
 
-    isEinProduct: bool
+    Attributes:
+        updateOp: The function to apply to the pointwise expressions (e.g. +=, f=, max=, etc...).
+        input_fields: The indices that are used in the pointwise expression (i.e. i, j, k).
+        output_fields: The indices that are used in the output (i.e. i, j).
+        pointwise_expr: The pointwise expression that is mapped and aggregated.
+    """
 
-    inputs: tuple[tuple[Field, ...], ...]
-    outputs: tuple[Field, ...]
-    args: tuple[LogicExpression, ...]
+    updateOp: Callable 
 
+    input_fields = tuple[Field, ...] # indices that are used in the pointwise expression (i.e. i, j, k)
+    output_fields = tuple[Field, ...] # a subset of input_fields that are used in the output (i.e. 
i, j) + pointwise_expr: PointwiseNode # the pointwise expression that is aggregated + + @classmethod + def from_children(cls, updateOp: Callable, input_fields: tuple[Field, ...], output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode) -> Self: + return cls(updateOp, input_fields, output_fields, pointwise_expr) + @property def children(self): - # Treat only args as children in the term tree - return list(self.args) - - @property - def fields(self) -> list[Field]: - return list(self.outputs) + return [self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr] class EinsumTransformer(DefaultLogicOptimizer): """ From 29a95436f0557bc4c4d0790994d3796bc5ced75a Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Thu, 11 Sep 2025 21:59:44 -0400 Subject: [PATCH 08/26] * Added bare bones einsum parser (called einsum lowerer) implementation that consumes optimized logic IR --- src/finchlite/autoschedule/einsum.py | 201 +++++++-------------------- 1 file changed, 48 insertions(+), 153 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 0aeb5c6f..5ba2c01f 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,12 +1,10 @@ -from ast import Call from dataclasses import dataclass from abc import ABC from typing import Callable, Self -import operator -from finchlite.finch_logic import LogicExpression, MapJoin, Aggregate, LogicNode, Alias, Table, Literal, Field, Relabel, Reorder -from finchlite.autoschedule import DefaultLogicOptimizer, LogicCompiler, optimize -from finchlite.symbolic import Rewrite, PostWalk, PostOrderDFS, Term, TermTree +from finchlite.finch_logic import LogicExpression, LogicNode, Field, Plan +from finchlite.symbolic import Term, TermTree +from finchlite.autoschedule import optimize @dataclass(eq=True, frozen=True) @@ -127,155 +125,52 @@ def from_children(cls, updateOp: Callable, input_fields: tuple[Field, ...], outp def children(self): return [self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr] -class EinsumTransformer(DefaultLogicOptimizer): +@dataclass(eq=True, frozen=True) +class EinsumPlan(Plan): + """ + EinsumPlan + + A plan that contains einsum operations. Basically a list of einsum operations. """ - Rewrite unoptimized Logic IR (mostly MapJoin and Aggregate) into Einsum nodes. - - Pattern handled: - - Aggregate(add, 0, MapJoin(mul, args), reduce_idxs) -> Einsum(args, outputs=reduce_idxs) - - Aggregate(add, 0, Relabel(MapJoin(mul, args), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) - - Aggregate(add, 0, Reorder(MapJoin(mul, args), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) - - Aggregate(add, 0, Relabel(Reorder(MapJoin(mul, args), _), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) - - Aggregate(add, 0, Reorder(Relabel(MapJoin(mul, args), _), _), reduce_idxs)-> Einsum(args, outputs=reduce_idxs) - - Aggregate(add, 0, Einsum(...), reduce_idxs)-> Einsum(..., outputs=reduce_idxs) - - MapJoin(mul, args)-> Einsum(args) #elementwise product (no contraction) - - Aggregate(mul, 0, MapJoin(mul, args), reduce_idxs)-> Einsum(args, outputs=reduce_idxs, isEinProduct=True) #elementwise product (no contraction) + bodies: tuple[Einsum, ...] 
- """ + @classmethod + def from_children(cls, bodies: tuple[Einsum, ...]) -> Self: + return cls(bodies) - def __init__(self, ctx: LogicCompiler, verbose=True): - super().__init__(ctx) - self.verbose = verbose - - def __call__(self, prgm: LogicNode): - prgm = optimize(prgm) - prgm = self.ctx(prgm) - - transformed = self.transform(prgm) - - return transformed - - def make_einsum(self, args: tuple[LogicExpression, ...], inputs: tuple[tuple[Field, ...], ...] | None = None, outputs: tuple[Field, ...] | None = None, isEinProduct: bool = False): - #inputs are the fields of the arguments by default - construct_inputs = inputs if inputs is not None else tuple(tuple(f for f in arg.fields) for arg in args) - - #outputs are the union of the inputs by default, or the union of the outputs if provided - union_fields: list[Field] = list(outputs) if outputs is not None else list() #union fields are inputs by default - for labels in construct_inputs: - for f in labels: - if f not in union_fields: - union_fields.append(f) - construct_outputs = tuple(union_fields) - return Einsum(args=args, inputs=construct_inputs, outputs=construct_outputs, isEinProduct=isEinProduct) - - def transform(self, prgm: LogicNode): - def rule(node): - match node: - - # Sum over product with harmless wrappers -> Einsum - case Aggregate(Literal(operator.add), Literal(0), Relabel(MapJoin(Literal(operator.mul), args), _), idxs): - return self.make_einsum(args=args, inputs=None, outputs=idxs) - case Aggregate(Literal(operator.add), Literal(0), Reorder(MapJoin(Literal(operator.mul), args), _), idxs): - return self.make_einsum(args=args, inputs=None, outputs=idxs) - case Aggregate( - Literal(operator.add), - Literal(0), - Relabel(Reorder(MapJoin(Literal(operator.mul), args), _), _), - idxs, - ): - return self.make_einsum(args=args, inputs=None, outputs=idxs) - case Aggregate( - Literal(operator.add), - Literal(0), - Reorder(Relabel(MapJoin(Literal(operator.mul), args), _), _), - idxs, - ): - return self.make_einsum(args=args, inputs=None, outputs=idxs) - - # Sum over already-converted Einsum (e.g., MapJoin->Einsum happened earlier) - case Aggregate(Literal(operator.add), Literal(0), Einsum(args=args, inputs=_, outputs=_), idxs): - return self.make_einsum(args=args, inputs=None, outputs=idxs) - - # Original core pattern matching rules - # Sum over product -> Einsum(no contraction) - case Aggregate(Literal(operator.add), Literal(0), MapJoin(Literal(operator.mul), args), idxs): - return self.make_einsum(args=args, inputs=None, outputs=idxs) - # Sum over product -> Einsum (no contraction) - case Aggregate(Literal(operator.mul), Literal(0), MapJoin(Literal(operator.mul), args), idxs): - return self.make_einsum(args=args, inputs=None, outputs=idxs, isEinProduct=True) - # Pure elementwise product -> Einsum (no contraction) - case MapJoin(Literal(operator.mul), args): - return self.make_einsum(args=args, inputs=None, outputs=None) - - return Rewrite(PostWalk(rule))(prgm) - - -class PrintingLogicOptimizer(DefaultLogicOptimizer): - """Custom optimizer that prints MapJoin and Aggregate operations""" - - def __init__(self, ctx: LogicCompiler, verbose=True): - super().__init__(ctx) - self.verbose = verbose - self.operation_count = {"MapJoin": 0, "Aggregate": 0} - - def __call__(self, prgm: LogicNode): - # First optimize the program - prgm = optimize(prgm) - - # Then traverse and print all MapJoin/Aggregate operations - if self.verbose: - print("\n=== Finch Logic IR Operations ===") - self._print_operations(prgm) - print(f"\nTotal MapJoins: 
{self.operation_count['MapJoin']}") - print(f"Total Aggregates: {self.operation_count['Aggregate']}") - print("================================\n") - - # Continue with compilation - return self.ctx(prgm) - - def _print_operations(self, node): - """Traverse the Logic IR and print MapJoin/Aggregate operations""" - for n in PostOrderDFS(node): - match n: - case MapJoin(op, args): - self.operation_count["MapJoin"] += 1 - print(f"\nMapJoin #{self.operation_count['MapJoin']}:") - print(f" Operation: {self._format_op(op)}") - print(f" Args: {self._format_args(args)}") - print(f" Fields: {n.fields}") - - case Aggregate(op, init, arg, idxs): - self.operation_count["Aggregate"] += 1 - print(f"\nAggregate #{self.operation_count['Aggregate']}:") - print(f" Operation: {self._format_op(op)}") - print(f" Init: {self._format_literal(init)}") - print(f" Reduce dims: {idxs}") - print(f" Input fields: {arg.fields if hasattr(arg, 'fields') else 'N/A'}") - print(f" Output fields: {n.fields}") - - def _format_op(self, op): - """Format operation for printing""" - if isinstance(op, Literal): - if hasattr(op.val, '__name__'): - return op.val.__name__ - return str(op.val) - return str(op) - - def _format_literal(self, lit): - """Format literal for printing""" - if isinstance(lit, Literal): - return str(lit.val) - return str(lit) - - def _format_args(self, args): - """Format arguments for printing""" - formatted = [] - for arg in args: - if isinstance(arg, Alias): - formatted.append(f"Alias({arg.name})") - elif isinstance(arg, Table): - formatted.append(f"Table(...)") - else: - formatted.append(type(arg).__name__) - return formatted \ No newline at end of file + @property + def children(self) -> tuple[Einsum, ...]: + return self.bodies + +def make_einsum_plan(bodies: tuple[Einsum | EinsumPlan, ...]) -> EinsumPlan: + """Flatten nested EinsumPlans so the resulting tuple contains only Einsum nodes.""" + flat: list[Einsum] = [] + for body in bodies: + if isinstance(body, EinsumPlan): + flat.extend(body.children) + else: + flat.append(body) + return EinsumPlan(tuple(flat)) + +class EinsumLowerer: + def __call__(self, ex: LogicNode) -> EinsumPlan: + match ex: + case Plan(bodies): + return make_einsum_plan(tuple(self(body) for body in bodies)) + case _: + raise Exception(f"Unrecognized logic: {ex}") + +class EinsumCompiler: + def __init__(self): + self.el = EinsumLowerer() + + def __call__(self, prgm: Plan): + return self.el(prgm) + + +def einsum_scheduler(plan: Plan): + optimized_prgm = optimize(plan) + + interpreter = EinsumCompiler() + return interpreter(optimized_prgm) \ No newline at end of file From 2f7f9517bf36f2cc9dde17ebc0f8b45297864524 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 12 Sep 2025 00:06:00 -0400 Subject: [PATCH 09/26] * Started implementing einsum lowerer. 
* Added support for recursive descent parsing of MapJoins

---
 src/finchlite/autoschedule/einsum.py | 75 ++++++++++++++++++++--------
 1 file changed, 55 insertions(+), 20 deletions(-)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index 5ba2c01f..7e637b92 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -2,9 +2,11 @@
 from abc import ABC
 from typing import Callable, Self
 
-from finchlite.finch_logic import LogicExpression, LogicNode, Field, Plan
+from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel
+from finchlite.finch_logic.nodes import MapJoin
 from finchlite.symbolic import Term, TermTree
 from finchlite.autoschedule import optimize
+from finchlite.algebra import is_commutative, identity
 
 
 @dataclass(eq=True, frozen=True)
@@ -40,17 +42,17 @@ class PointwiseAccess(PointwiseNode, TermTree):
         idxs: The indices at which to access the tensor.
     """
 
-    tensor: LogicExpression
+    alias: str
     idxs: tuple[Field, ...] # (Field('i'), Field('j'))
     # Children: None (leaf)
 
     @classmethod
-    def from_children(cls, tensor: LogicExpression, idxs: tuple[Field, ...]) -> Self:
-        return cls(tensor, idxs)
+    def from_children(cls, alias: str, idxs: tuple[Field, ...]) -> Self:
+        return cls(alias, idxs)
 
     @property
     def children(self):
-        return [self.tensor, *self.idxs]
+        return [self.alias, *self.idxs]
 
 @dataclass(eq=True, frozen=True)
 class PointwiseOp(PointwiseNode):
@@ -61,7 +63,7 @@ class PointwiseOp(PointwiseNode):
     If the operation is not commutative, the node must be binary, taking exactly two args.
 
     Attributes:
-        op: The function to apply e.g., operator.add
+        op: The function to apply e.g., operator.add, operator.mul, operator.sub, operator.truediv, etc... Must be a callable.
         args: The arguments to the operation.
     """
 
@@ -87,13 +89,12 @@ class PointwiseLiteral(PointwiseNode):
 
     val: float
 
-    @classmethod
-    def from_children(cls, val: float) -> Self:
-        return cls(val)
+    def __hash__(self):
+        return hash(self.val)
+
+    def __eq__(self, other):
+        return isinstance(other, PointwiseLiteral) and self.val == other.val
 
-    @property
-    def children(self):
-        return [self.val]
 
 #einsum and einsum ast not part of logic IR
 #transform to it's own language
@@ -111,19 +112,20 @@ class Einsum(TermTree):
         pointwise_expr: The pointwise expression that is mapped and aggregated.
     """
 
-    updateOp: Callable 
+    updateOp: Callable
 
-    input_fields = tuple[Field, ...] # indices that are used in the pointwise expression (i.e. i, j, k)
-    output_fields = tuple[Field, ...] # a subset of input_fields that are used in the output (i.e. i, j)
-    pointwise_expr: PointwiseNode # the pointwise expression that is aggregated
+    input_fields: tuple[Field, ...]
+    output_fields: tuple[Field, ...] 
+    pointwise_expr: PointwiseNode
+    output_alias: str | None
 
     @classmethod
-    def from_children(cls, updateOp: Callable, input_fields: tuple[Field, ...], output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode) -> Self:
-        return cls(updateOp, input_fields, output_fields, pointwise_expr)
+    def from_children(cls, output_alias: str | None, updateOp: Callable, input_fields: tuple[Field, ...], output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode) -> Self:
+        return cls(output_alias, updateOp, input_fields, output_fields, pointwise_expr)
 
     @property
     def children(self):
-        return [self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr]
+        return [self.output_alias, self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr]
 
@@ -162,11 +164,43 @@ def make_einsum_plan(bodies: tuple[Einsum | EinsumPlan, ...]) -> EinsumPlan:
     return EinsumPlan(tuple(flat))
 
 class EinsumLowerer:
-    def __call__(self, ex: LogicNode) -> EinsumPlan:
+    def __call__(self, ex: LogicNode) -> EinsumPlan | Einsum:
         match ex:
             case Plan(bodies):
                 return make_einsum_plan(tuple(self(body) for body in bodies))
+            case Query(Alias(name), rhs):
+                rhsEinsum = self(rhs)
+                return Einsum(output_alias=name, updateOp=rhsEinsum.updateOp, input_fields=rhsEinsum.input_fields, output_fields=rhsEinsum.output_fields, pointwise_expr=rhsEinsum.pointwise_expr)
+            case MapJoin(Literal(operation), args):
+                args = [self.lower_to_pointwise(arg) for arg in args]
+                pointwise_expr = self.lower_to_pointwise_op(operation, args)
+                return Einsum(output_alias=None, updateOp=identity, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr)
             case _:
                 raise Exception(f"Unrecognized logic: {ex}")
 
+    def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, ...]) -> PointwiseOp:
+        # if operation is commutative, we simply pass all the args to the pointwise op since order of args does not matter
+        if is_commutative(operation):
+            return PointwiseOp(operation, tuple(args))
+
+        # combine args from left to right (i.e. a / b / c -> (a / b) / c)
+        assert len(args) > 1
+        result = PointwiseOp(operation, (args[0], args[1]))
+        for arg in args[2:]:
+            result = PointwiseOp(operation, (result, arg))
+        return result
+
+    # lowers nested mapjoin logic IR nodes into a single pointwise expression
+    def lower_to_pointwise(self, ex: LogicNode) -> PointwiseNode:
+        match ex:
+            case MapJoin(Literal(operation), args):
+                args = [self.lower_to_pointwise(arg) for arg in args]
+                return self.lower_to_pointwise_op(operation, args)
+
+            case Relabel(Alias(name), idxs): # relabel is really just a glorified pointwise access
+                return PointwiseAccess(alias=name, idxs=idxs)
+            case Literal(value):
+                return PointwiseLiteral(val=value)
+            case _:
+                raise Exception(f"Unrecognized logic: {ex}")
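
The left-to-right fold in lower_to_pointwise_op is worth a worked example: a non-commutative op builds a binary tree, while a commutative op keeps a flat argument list. A small sketch under the tuple-style PointwiseOp construction above; el, a, b, and c are hypothetical names, and the is_commutative property itself only lands in patch 13 below:

    import operator

    el = EinsumLowerer()
    a = PointwiseAccess(alias="a", idxs=(Field("i"),))
    b = PointwiseAccess(alias="b", idxs=(Field("i"),))
    c = PointwiseAccess(alias="c", idxs=(Field("i"),))

    # a / b / c folds left: (a / b) / c
    expr = el.lower_to_pointwise_op(operator.truediv, (a, b, c))
    # expr == PointwiseOp(operator.truediv,
    #                     (PointwiseOp(operator.truediv, (a, b)), c))

    # a + b + c stays flat, since + is commutative
    flat = el.lower_to_pointwise_op(operator.add, (a, b, c))
    # flat == PointwiseOp(operator.add, (a, b, c))
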

From fd377ccad5ffbad0aa54c81bd59ea6db4b72c574 Mon Sep 17 00:00:00 2001
From: TheRealMichaelWang
Date: Fri, 12 Sep 2025 08:24:42 -0400
Subject: [PATCH 10/26] * Added support for handling Reorder statements

---
 src/finchlite/autoschedule/einsum.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index 7e637b92..3d7491ad 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -3,7 +3,7 @@
 from typing import Callable, Self
 
 from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel
-from finchlite.finch_logic.nodes import MapJoin
+from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder
 from finchlite.symbolic import Term, TermTree
 from finchlite.autoschedule import optimize
 from finchlite.algebra import is_commutative, identity
@@ -127,6 +127,15 @@ class Einsum(TermTree):
     def children(self):
         return [self.output_alias, self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr]
 
+    @classmethod
+    def rename(self, new_alias: str):
+        return Einsum(self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias)
+
+    @classmethod
+    def reorder(self, idxs: tuple[int, ...]):
+        new_input_fields = tuple(self.input_fields[i] for i in idxs)
+        return Einsum(self.updateOp, new_input_fields, self.output_fields, self.pointwise_expr, self.output_alias)
+
 @dataclass(eq=True, frozen=True)
 class EinsumPlan(Plan):
@@ -162,11 +171,24 @@ class EinsumLowerer:
             case Query(Alias(name), rhs):
                 rhsEinsum = self(rhs)
-                return Einsum(output_alias=name, updateOp=rhsEinsum.updateOp, input_fields=rhsEinsum.input_fields, output_fields=rhsEinsum.output_fields, pointwise_expr=rhsEinsum.pointwise_expr)
+                if isinstance(rhsEinsum, EinsumPlan):
+                    raise Exception("Cannot alias an einsum plan.")
+                return rhsEinsum.rename(name)
             case MapJoin(Literal(operation), args):
                 args = [self.lower_to_pointwise(arg) for arg in args]
                 pointwise_expr = self.lower_to_pointwise_op(operation, args)
                 return Einsum(output_alias=None, updateOp=identity, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr)
+            case Reorder(arg, idxs):
+                argEinsum = self(arg)
+                if isinstance(argEinsum, EinsumPlan):
+                    raise Exception("Cannot reorder an einsum plan.")
+                return argEinsum.reorder(idxs)
+
+            case Produces(arg):
+                argEinsum = self(arg)
+                if isinstance(argEinsum, EinsumPlan):
+                    raise Exception("Cannot produce an einsum plan.")
+                return argEinsum.rename("final_output")
             case _:
                 raise Exception(f"Unrecognized logic: {ex}")
 
@@ -188,7 +210,6 @@ class EinsumLowerer:
         match ex:
             case MapJoin(Literal(operation), args):
                 args = [self.lower_to_pointwise(arg) for arg in args]
                 return self.lower_to_pointwise_op(operation, args)
-
             case Relabel(Alias(name), idxs): # relabel is really just a glorified pointwise access
                 return PointwiseAccess(alias=name, idxs=idxs)
             case Literal(value):
                 return PointwiseLiteral(val=value)
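
A quick illustration of the two helpers patch 10 adds. Note they are declared @classmethod but written as instance methods taking self; patch 14 below removes the decorators, and under that later fix usage would look like the hypothetical sketch here:

    # rename the produced tensor
    named = einsum.rename("C")

    # permute input fields by position, per the tuple[int, ...] signature here
    swapped = einsum.reorder((1, 0))

Both return a new Einsum rather than mutating in place, matching the dataclass(frozen=True) declaration; patch 11 later changes reorder to take the target Field order directly instead of integer positions.
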

From f829d0431b8dc74cc0b2e1e390bfdceaf3dd04b5 Mon Sep 17 00:00:00 2001
From: TheRealMichaelWang
Date: Fri, 12 Sep 2025 11:54:05 -0400
Subject: [PATCH 11/26] * Restructured the recursive descent parsing to handle
 top-level Plan logic IR statements separately from value logic IR statements

* Separates into compile_plan and lower_to_pointwise_op
* Added support for handling aggregates (use builtin init_value only for each
  reduce function, non-standard init values aren't supported)

---
 src/finchlite/autoschedule/einsum.py | 112 ++++++++++++++++-----------
 1 file changed, 68 insertions(+), 44 deletions(-)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index 3d7491ad..38f7853d 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -6,8 +6,7 @@
 from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder
 from finchlite.symbolic import Term, TermTree
 from finchlite.autoschedule import optimize
-from finchlite.algebra import is_commutative, identity
+from finchlite.algebra import is_commutative, overwrite, init_value
 
 @dataclass(eq=True, frozen=True)
@@ -112,7 +111,7 @@ class Einsum(TermTree):
         pointwise_expr: The pointwise expression that is mapped and aggregated.
     """
 
-    updateOp: Callable
+    reduceOp: Callable #technically a reduce operation, much akin to the one in aggregate
 
     input_fields: tuple[Field, ...]
     output_fields: tuple[Field, ...] 
@@ -125,16 +124,15 @@ class Einsum(TermTree):
 
     @property
     def children(self):
-        return [self.output_alias, self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr]
+        return [self.output_alias, self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr]
 
     @classmethod
     def rename(self, new_alias: str):
-        return Einsum(self.updateOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias)
+        return Einsum(self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias)
 
     @classmethod
-    def reorder(self, idxs: tuple[int, ...]):
-        new_input_fields = tuple(self.input_fields[i] for i in idxs)
-        return Einsum(self.updateOp, new_input_fields, self.output_fields, self.pointwise_expr, self.output_alias)
+    def reorder(self, idxs: tuple[Field, ...]):
+        return Einsum(self.reduceOp, idxs, self.output_fields, self.pointwise_expr, self.output_alias)
 
@@ -145,16 +143,43 @@ class EinsumPlan(Plan):
     A plan that contains einsum operations. Basically a list of einsum operations.
     """
 
     bodies: tuple[Einsum, ...]
+    returnValue: Einsum | None
 
     @classmethod
-    def from_children(cls, bodies: tuple[Einsum, ...]) -> Self:
-        return cls(bodies)
+    def from_children(cls, bodies: tuple[Einsum, ...], returnValue: Einsum | None) -> Self:
+        return cls(bodies, returnValue)
 
     @property
-    def children(self) -> tuple[Einsum, ...]:
-        return self.bodies
+    def children(self) -> tuple[Einsum, ...]:
+        return [*self.bodies, self.returnValue]
 
-def make_einsum_plan(bodies: tuple[Einsum | EinsumPlan, ...]) -> EinsumPlan:
-    """Flatten nested EinsumPlans so the resulting tuple contains only Einsum nodes."""
-    flat: list[Einsum] = []
-    for body in bodies:
-        if isinstance(body, EinsumPlan):
-            flat.extend(body.children)
-        else:
-            flat.append(body)
-    return EinsumPlan(tuple(flat))
 
 class EinsumLowerer:
+    alias_counter: int = 0
+
-    def __call__(self, ex: LogicNode) -> EinsumPlan | Einsum:
-        match ex:
-            case Plan(bodies):
-                return make_einsum_plan(tuple(self(body) for body in bodies))
+    def __call__(self, prgm: Plan) -> EinsumPlan:
+        return self.compile_plan(prgm)
+
+    def get_next_alias(self) -> str:
+        self.alias_counter += 1
+        return f"einsum_{self.alias_counter}"
+
+    def compile_plan(self, plan: Plan) -> EinsumPlan:
+        einsums = []
+        returnValue = None
+
+        for body in plan.bodies:
+            match body:
+                case Plan(_):
+                    plan = self.compile_plan(body)
+                    if plan.returnValue is not None:
+                        raise Exception("Plans with return values are not statements, but rather are expressions.")
+                    einsums.extend(plan.bodies)
+                case Query(Alias(name), rhs):
+                    einsums.append(self.lower_to_einsum(rhs, einsums).rename(name))
+                case Produces(arg):
+                    if returnValue is not None:
+                        raise Exception("Only one return value is supported.")
+                    returnValue = self.lower_to_einsum(arg, einsums)
+                case _:
+                    einsums.append(self.lower_to_einsum(body, einsums).rename(self.get_next_alias()))
+
+        return EinsumPlan(tuple(einsums), returnValue)
+
+    def lower_to_einsum(self, ex: LogicNode, einsums: list[Einsum]) -> Einsum:
+        match ex:
+            case Plan(_):
+                plan = self.compile_plan(ex)
+                einsums.extend(plan.bodies)
+                return plan.returnValue
-            case Query(Alias(name), rhs):
-                rhsEinsum = self(rhs)
-                if isinstance(rhsEinsum, EinsumPlan):
-                    raise Exception("Cannot alias an einsum plan.")
-                return rhsEinsum.rename(name)
             case MapJoin(Literal(operation), args):
-                args = [self.lower_to_pointwise(arg) for arg in args]
+                args = [self.lower_to_pointwise(arg, einsums) for arg in args]
                 pointwise_expr = self.lower_to_pointwise_op(operation, args)
-                return Einsum(output_alias=None, updateOp=identity, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr)
+                return Einsum(reduceOp=overwrite, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None)
             case Reorder(arg, idxs):
-                argEinsum = self(arg)
-                if isinstance(argEinsum, EinsumPlan):
-                    raise Exception("Cannot reorder an einsum plan.")
-                return argEinsum.reorder(idxs)
-
-            case Produces(arg):
-                argEinsum = self(arg)
-                if isinstance(argEinsum, EinsumPlan):
-                    raise Exception("Cannot produce an einsum plan.")
-                return argEinsum.rename("final_output")
+                return self.lower_to_einsum(arg, einsums).reorder(idxs)
+            case Aggregate(operation, init, arg, idxs):
+                if init != init_value(operation, type(init)):
+                    raise Exception(f"Init value {init} is not the default value for operation {operation} of type {type(init)}. Non-standard init values are not supported.")
+                pointwise_expr = self.lower_to_pointwise(arg, einsums)
+                return Einsum(operation, arg.fields, ex.fields, pointwise_expr, output_alias=None)
             case _:
                 raise Exception(f"Unrecognized logic: {ex}")
 
     def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, ...]) -> PointwiseOp:
         # if operation is commutative, we simply pass all the args to the pointwise op since order of args does not matter
         if is_commutative(operation):
-            return PointwiseOp(operation, tuple(args))
+            flat_args = [] # flatten nested args of the same commutative op
+            for arg in args:
+                match arg:
+                    case PointwiseOp(op2, _) if op2 == operation:
+                        flat_args.extend(arg.args)
+                    case _:
+                        flat_args.append(arg)
+
+            return PointwiseOp(operation, tuple(flat_args))
 
         # combine args from left to right (i.e. a / b / c -> (a / b) / c)
@@ -205,15 +226,19 @@ class EinsumLowerer:
 
     # lowers nested mapjoin logic IR nodes into a single pointwise expression
-    def lower_to_pointwise(self, ex: LogicNode) -> PointwiseNode:
+    def lower_to_pointwise(self, ex: LogicNode, einsums: list[Einsum]) -> PointwiseNode:
         match ex:
             case MapJoin(Literal(operation), args):
-                args = [self.lower_to_pointwise(arg) for arg in args]
+                args = [self.lower_to_pointwise(arg, einsums) for arg in args]
                 return self.lower_to_pointwise_op(operation, args)
             case Relabel(Alias(name), idxs): # relabel is really just a glorified pointwise access
                 return PointwiseAccess(alias=name, idxs=idxs)
             case Literal(value):
                 return PointwiseLiteral(val=value)
+            case Aggregate(_, _, _, _): # an aggregate has to be computed separately as its own einsum
+                aggregate_einsum_alias = self.get_next_alias()
+                einsums.append(self.lower_to_einsum(ex, einsums).rename(aggregate_einsum_alias))
+                return PointwiseAccess(alias=aggregate_einsum_alias, idxs=tuple(ex.fields))
             case _:
                 raise Exception(f"Unrecognized logic: {ex}")
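
The flattening step added to lower_to_pointwise_op deserves a concrete case: nested applications of the same commutative operator collapse into one n-ary node, which keeps later printing and scheduling simpler. A minimal sketch with hypothetical leaves x, y, z and lowerer el:

    import operator

    x = PointwiseAccess(alias="x", idxs=(Field("i"),))
    y = PointwiseAccess(alias="y", idxs=(Field("i"),))
    z = PointwiseAccess(alias="z", idxs=(Field("i"),))

    inner = PointwiseOp(operator.mul, (y, z))
    flat = el.lower_to_pointwise_op(operator.mul, (x, inner))
    # flat == PointwiseOp(operator.mul, (x, y, z))

Only one level of nesting needs handling here, because lower_to_pointwise builds expressions bottom-up: any nested PointwiseOp of the same operator has already been flattened by the time its parent is constructed.
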

From 3b3ba24f0557f54fc3e43c46b69b6de3e32b2dc9 Mon Sep 17 00:00:00 2001
From: TheRealMichaelWang
Date: Fri, 12 Sep 2025 12:20:10 -0400
Subject: [PATCH 12/26] * Added support for printing einsums, pointwise
 operations, and einsum plans

---
 src/finchlite/autoschedule/einsum.py | 71 +++++++++++++++++++++++++++-
 1 file changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py
index 38f7853d..f302aeff 100644
--- a/src/finchlite/autoschedule/einsum.py
+++ b/src/finchlite/autoschedule/einsum.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from abc import ABC
+import operator
 from typing import Callable, Self
 
 from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel
@@ -7,6 +8,7 @@
 from finchlite.symbolic import Term, TermTree
 from finchlite.autoschedule import optimize
 from finchlite.algebra import is_commutative, overwrite, init_value
+from finchlite.symbolic import Context
 
 @dataclass(eq=True, frozen=True)
 class PointwiseNode(Term, ABC):
@@ -29,6 +31,10 @@ class PointwiseNode(Term, ABC):
     def from_children(cls, *children: Term) -> Self:
         return cls(*children)
 
+    def __str__(self):
+        ctx = EinsumPrinterContext()
+        return ctx.print_pointwise_expr(self)
+
 @dataclass(eq=True, frozen=True)
 class PointwiseAccess(PointwiseNode, TermTree):
@@ -134,6 +140,10 @@ class Einsum(TermTree):
     def reorder(self, idxs: tuple[Field, ...]):
         return Einsum(self.reduceOp, idxs, self.output_fields, self.pointwise_expr, self.output_alias)
 
+    def __str__(self):
+        ctx = EinsumPrinterContext()
+        return ctx.print_einsum(self)
+
 @dataclass(eq=True, frozen=True)
 class EinsumPlan(Plan):
@@ -153,6 +163,10 @@ class EinsumPlan(Plan):
     def children(self) -> tuple[Einsum, ...]:
         return [*self.bodies, self.returnValue]
 
+    def __str__(self):
+        ctx = EinsumPrinterContext()
+        return ctx(self)
+
 class EinsumLowerer:
     alias_counter: int = 0
@@ -253,4 +267,59 @@ def einsum_scheduler(plan: Plan):
     optimized_prgm = optimize(plan)
 
     interpreter = EinsumCompiler()
-    return interpreter(optimized_prgm)
\ No newline at end of file
+    return interpreter(optimized_prgm)
+
+class EinsumPrinterContext(Context):
+    def print_indicies(self, idxs: tuple[Field, ...]):
+        return ", ".join([str(idx) for idx in idxs])
+
+    def print_reducer(self, reducer: Callable):
+        str_map = {
+            overwrite: "=",
+            operator.add: "+=",
+            operator.sub: "-=",
+            operator.mul: "*=",
+            operator.truediv: "/=",
+            operator.floordiv: "//=",
+            operator.mod: "%=",
+            operator.pow: "**=",
+            operator.and_: "&=",
+            operator.or_: "|=",
+            operator.xor: "^=",
+        }
+        return str_map[reducer]
+
+    def print_pointwise_op(self, op: Callable):
+        str_map = {
+            operator.add: "+",
+            operator.sub: "-",
+            operator.mul: "*",
+            operator.truediv: "/",
+            operator.mod: "%",
+            operator.pow: "**",
+        }
+        return str_map[op]
+
+    def print_pointwise_expr(self, pointwise_expr: PointwiseNode):
+        match pointwise_expr:
+            case PointwiseAccess(alias, idxs):
+                return f"{alias}[{self.print_indicies(idxs)}]"
+            case PointwiseOp(op, args):
+                return f"{self.print_pointwise_op(op)}({', '.join(self.print_pointwise_expr(arg) for arg in args)})"
+            case PointwiseLiteral(val):
+                return str(val)
+
+    def print_einsum(self, einsum: Einsum):
+        return f"{einsum.output_alias}[{self.print_indicies(einsum.output_fields)}] {self.print_reducer(einsum.reduceOp)} {self.print_pointwise_expr(einsum.pointwise_expr)}"
+
+    def print_einsum_plan(self, einsum_plan: EinsumPlan):
+        body = "\n".join(self.print_einsum(einsum) for einsum in einsum_plan.bodies)
+        if einsum_plan.returnValue is None:
+            return body
+        return f"{body}\nreturn {self.print_einsum(einsum_plan.returnValue)}"
+
+    def __call__(self, prgm: EinsumPlan):
+        return self.print_einsum_plan(prgm)
\ No newline at end of file
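
What the printer in patch 12 produces, for the matmul plan used throughout this series: given an Einsum with reduceOp=operator.add, output_fields=(i, j), pointwise expression A[i, k] * B[k, j], and output alias "C", print_einsum renders roughly

    C[i, j] += *(A[i, k], B[k, j])

assuming str(Field) yields the bare index name. Pointwise operations print in prefix form (the operator symbol followed by its parenthesized arguments), and a plan prints one einsum per line, followed by an optional trailing "return" line when the plan has a return value.
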

From fd377ccad5ffbad0aa54c81bd59ea6db4b72c574 Mon Sep 17 00:00:00 2001
From: TheRealMichaelWang
Date: Fri, 12 Sep 2025 12:33:47 -0400
Subject: [PATCH 13/26] * Added support for is_commutative property for
 operations in algebra.py

---
 src/finchlite/algebra/__init__.py |  2 ++
 src/finchlite/algebra/algebra.py  | 12 ++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/src/finchlite/algebra/__init__.py b/src/finchlite/algebra/__init__.py
index 817fe41b..ec818027 100644
--- a/src/finchlite/algebra/__init__.py
+++ b/src/finchlite/algebra/__init__.py
@@ -4,6 +4,7 @@
     init_value,
     is_annihilator,
     is_associative,
+    is_commutative,
     is_distributive,
     is_identity,
     promote_type,
@@ -46,6 +47,7 @@
     "is_annihilator",
     "is_associative",
+    "is_commutative",
     "is_distributive",
     "is_identity",
     "overwrite",

diff --git a/src/finchlite/algebra/algebra.py b/src/finchlite/algebra/algebra.py
index 3f6ee89c..ab0c62ad 100644
--- a/src/finchlite/algebra/algebra.py
+++ b/src/finchlite/algebra/algebra.py
@@ -345,6 +345,18 @@ def is_associative(op: Any) -> bool:
 register_property(np.logical_or, "__call__", "is_associative", lambda op: True)
 register_property(np.logical_xor, "__call__", "is_associative", lambda op: True)
 
+# Commutative properties
+
+for op in (operator.add, operator.mul, operator.and_, operator.or_, operator.xor,
+           np.logical_and, np.logical_or, np.logical_xor):
+    register_property(op, "__call__", "is_commutative", lambda op: True)
+
+for op in (operator.sub, operator.truediv, operator.floordiv, operator.pow,
+           operator.lshift, operator.rshift):
+    register_property(op, "__call__", "is_commutative", lambda op: False)
+
+def is_commutative(op: Any) -> bool:
+    return query_property(op, "__call__", "is_commutative")
 
 def is_identity(op: Any, val: Any) -> bool:
     """
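
The property registry in patch 13 is queried like any other finchlite algebra property. A small sanity check, assuming query_property returns the registered value for these operators:

    import operator
    from finchlite.algebra import is_commutative

    assert is_commutative(operator.add)
    assert is_commutative(operator.mul)
    assert not is_commutative(operator.sub)
    assert not is_commutative(operator.truediv)

Operators with no registered entry (user-defined callables, for instance) fall through to whatever default query_property provides, so callers like lower_to_pointwise_op should expect the conservative binary fold in that case.
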
finchlite.symbolic import Term, TermTree from finchlite.autoschedule import optimize from finchlite.algebra import is_commutative, overwrite, init_value @@ -132,11 +133,9 @@ def from_children(cls, output_alias: str | None, updateOp: Callable, input_field def children(self): return [self.output_alias, self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr] - @classmethod def rename(self, new_alias: str): return Einsum(self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias) - @classmethod def reorder(self, idxs: tuple[Field, ...]): return Einsum(self.reduceOp, idxs, self.output_fields, self.pointwise_expr, self.output_alias) @@ -152,16 +151,16 @@ class EinsumPlan(Plan): A plan that contains einsum operations. Basically a list of einsum operations. """ - bodies: tuple[Einsum, ...] - returnValue: Einsum | None + bodies: tuple[Einsum, ...] = () + returnValues: tuple[Einsum | str, ...] = () @classmethod - def from_children(cls, bodies: tuple[Einsum, ...], returnValue: Einsum | None) -> Self: + def from_children(cls, bodies: tuple[Einsum, ...], returnValue: tuple[Einsum | str, ...]) -> Self: return cls(bodies, returnValue) @property - def children(self) -> tuple[Einsum, ...]: - return [*self.bodies, self.returnValue] + def children(self): + return [*self.bodies, self.returnValues] def __str__(self): ctx = EinsumPrinterContext() return ctx(self) @@ -170,67 +169,89 @@ def __str__(self): class EinsumLowerer: alias_counter: int = 0 - def __call__(self, prgm: Plan) -> EinsumPlan: - return self.compile_plan(prgm) + def __call__(self, prgm: Plan, parameters: dict[str, Table], definitions: dict[str, Einsum]) -> EinsumPlan: + return self.compile_plan(prgm, parameters, definitions) def get_next_alias(self) -> str: self.alias_counter += 1 return f"einsum_{self.alias_counter}" + def rename_einsum(self, einsum: Einsum, new_alias: str, definitions: dict[str, Einsum]) -> Einsum: + definitions[new_alias] = einsum + return einsum.rename(new_alias) + - def compile_plan(self, plan: Plan) -> EinsumPlan: + def compile_plan(self, plan: Plan, parameters: dict[str, Table], definitions: dict[str, Einsum]) -> EinsumPlan: einsums = [] - returnValue = None + returnValue = [] for body in plan.bodies: match body: case Plan(_): - plan = self.compile_plan(body) - if plan.returnValue is not None: - raise Exception("Plans with return values are not statements, but rather are expressions.") - einsums.extend(plan.bodies) + einsum_plan = self.compile_plan(body, parameters, definitions) + einsums.extend(einsum_plan.bodies) + + if einsum_plan.returnValues: + if returnValue: + raise Exception("Cannot invoke return more than once.") + returnValue = einsum_plan.returnValues + case Query(Alias(name), Table(_, _)): + parameters[name] = body.rhs case Query(Alias(name), rhs): - einsums.append(self.lower_to_einsum(rhs, einsums).rename(name)) + einsums.append(self.rename_einsum(self.lower_to_einsum(rhs, einsums, parameters, definitions), name, definitions)) - case Produces(arg): - if returnValue is not None: - raise Exception("Only one return value is supported.") - returnValue = self.lower_to_einsum(arg, einsums) + case Produces(args): + if returnValue: + raise Exception("Cannot invoke return more than once.") + for arg in args: + returnValue.append(arg.name if isinstance(arg, Alias) else self.lower_to_einsum(arg, einsums, parameters, definitions)) case _: + einsums.append(self.rename_einsum(self.lower_to_einsum(body, einsums, parameters, definitions), self.get_next_alias(), definitions)) - return EinsumPlan(tuple(einsums), returnValue) + return EinsumPlan(tuple(einsums), tuple(returnValue))
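To make the lowering concrete: for a program that computes C = A @ B and then returns C, compile_plan should produce an EinsumPlan that prints roughly as below (aliases A, B, C and fields i, j, k are illustrative; the exact text depends on the printer further down):

    C[i, j] += (A[i, k] * B[k, j])
    return C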
- def lower_to_einsum(self, ex: LogicNode, einsums: list[Einsum]) -> Einsum: + def lower_to_einsum(self, ex: LogicNode, einsums: list[Einsum], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> Einsum: match ex: case Plan(_): - plan = self.compile_plan(ex) + plan = self.compile_plan(ex, parameters, definitions) einsums.extend(plan.bodies) - return plan.returnValue + + if not plan.returnValues: + raise Exception("Plans with no return value are statements, not expressions.") + + if len(plan.returnValues) > 1: + raise Exception("Only one return value is supported.") + + if isinstance(plan.returnValues[0], str): + returned_alias = plan.returnValues[0] + returned_einsum = definitions[returned_alias] + return PointwiseAccess(alias=returned_alias, idxs=returned_einsum.output_fields) + + return plan.returnValues[0] case MapJoin(Literal(operation), args): - args = [self.lower_to_pointwise(arg, einsums) for arg in args] + args = [self.lower_to_pointwise(arg, einsums, parameters, definitions) for arg in args] pointwise_expr = self.lower_to_pointwise_op(operation, args) return Einsum(reduceOp=overwrite, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) case Reorder(arg, idxs): - return self.lower_to_einsum(arg, einsums).reorder(idxs) + return self.lower_to_einsum(arg, einsums, parameters, definitions).reorder(idxs) - case Aggregate(operation, init, arg, idxs): + case Aggregate(Literal(operation), Literal(init), arg, idxs): if init != init_value(operation, type(init)): raise Exception(f"Init value {init} is not the default value for operation {operation} of type {type(init)}. Non-standard init values are not supported.") - pointwise_expr = self.lower_to_pointwise(arg, einsums) - return Einsum(operation, arg.fields, ex.fields, pointwise_expr, self.get_next_alias(), output_alias=None) + pointwise_expr = self.lower_to_pointwise(arg, einsums, parameters, definitions) + return Einsum(operation, arg.fields, ex.fields, pointwise_expr, self.get_next_alias()) case _: raise Exception(f"Unrecognized logic: {ex}")
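The Aggregate case above is the core of the mapping: an Aggregate over a MapJoin becomes exactly one einsum. As a dense-numpy analogy (names are illustrative, plain numpy rather than finchlite API):

    import numpy as np

    A = np.random.rand(2, 3)
    B = np.random.rand(3, 4)
    # Aggregate(add, 0, MapJoin(mul, A, B), (k,)) corresponds to
    #   C[i, j] += (A[i, k] * B[k, j])
    # which numpy spells:
    C = np.einsum("ik,kj->ij", A, B)
    assert np.allclose(C, A @ B)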
Non standard init values are not supported.") - pointwise_expr = self.lower_to_pointwise(arg, einsums) - return Einsum(operation, arg.fields, ex.fields, pointwise_expr, self.get_next_alias(), output_alias=None) + pointwise_expr = self.lower_to_pointwise(arg, einsums, parameters, definitions) + return Einsum(operation, arg.fields, ex.fields, pointwise_expr, self.get_next_alias()) case _: raise Exception(f"Unrecognized logic: {ex}") def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, ...]) -> PointwiseOp: # if operation is commutative, we simply pass all the args to the pointwise op since order of args does not matter if is_commutative(operation): - args = [] # flatten the args + ret_args = [] # flatten the args for arg in args: match arg: case PointwiseOp(op2, _) if op2 == operation: - args.extend(arg.args) + ret_args.extend(arg.args) case _: - args.append(arg) + ret_args.append(arg) - return PointwiseOp(operation, args) + return PointwiseOp(operation, ret_args) # combine args from left to right (i.e a / b / c -> (a / b) / c) assert len(args) > 1 @@ -240,10 +261,12 @@ def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, return result # lowers nested mapjoin logic IR nodes into a single pointwise expression - def lower_to_pointwise(self, ex: LogicNode, einsums: list[Einsum]) -> PointwiseNode: + def lower_to_pointwise(self, ex: LogicNode, einsums: list[Einsum], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> PointwiseNode: match ex: + case Reorder(arg, idxs): + return self.lower_to_pointwise(arg, einsums, parameters, definitions) case MapJoin(Literal(operation), args): - args = [self.lower_to_pointwise(arg, einsums) for arg in args] + args = [self.lower_to_pointwise(arg, einsums, parameters, definitions) for arg in args] return self.lower_to_pointwise_op(operation, args) case Relabel(Alias(name), idxs): # relable is really just a glorified pointwise access return PointwiseAccess(alias=name, idxs=idxs) @@ -251,7 +274,7 @@ def lower_to_pointwise(self, ex: LogicNode, einsums: list[Einsum]) -> PointwiseN return PointwiseLiteral(val=value) case Aggregate(_, _, _, _): # aggregate has to be computed seperatley as it's own einsum aggregate_einsum_alias = self.get_next_alias() - einsums.append(self.lower_to_einsum(ex, einsums).rename(aggregate_einsum_alias)) + einsums.append(self.rename_einsum(self.lower_to_einsum(ex, einsums, parameters, definitions), aggregate_einsum_alias, definitions)) return PointwiseAccess(alias=aggregate_einsum_alias, idxs=tuple(ex.fields)) case _: raise Exception(f"Unrecognized logic: {ex}") @@ -261,15 +284,17 @@ def __init__(self): self.el = EinsumLowerer() def __call__(self, prgm: Plan): - return self.el(prgm) + parameters = {} + definitions = {} + return self.el(prgm, parameters, definitions), parameters, definitions def einsum_scheduler(plan: Plan): optimized_prgm = optimize(plan) - interpreter = EinsumCompiler() - return interpreter(optimized_prgm) + compiler = EinsumCompiler() + return compiler(optimized_prgm) -class EinsumPrinterContext(Context): +class EinsumPrinterContext: def print_indicies(self, idxs: tuple[Field, ...]): return ", ".join([str(idx) for idx in idxs]) @@ -279,7 +304,6 @@ def print_reducer(self, reducer: Callable): operator.add: "+=", operator.sub: "-=", operator.mul: "*=", - operator.div: "/=", operator.truediv: "/=", operator.mod: "%=", operator.pow: "**=", @@ -292,34 +316,41 @@ def print_reducer(self, reducer: Callable): } return str_map[reducer] - def print_pointwise_op(self, 
@@ -261,15 +284,17 @@ def __init__(self): self.el = EinsumLowerer() def __call__(self, prgm: Plan): - return self.el(prgm) + parameters = {} + definitions = {} + return self.el(prgm, parameters, definitions), parameters, definitions def einsum_scheduler(plan: Plan): optimized_prgm = optimize(plan) - interpreter = EinsumCompiler() - return interpreter(optimized_prgm) + compiler = EinsumCompiler() + return compiler(optimized_prgm) -class EinsumPrinterContext(Context): +class EinsumPrinterContext: def print_indicies(self, idxs: tuple[Field, ...]): return ", ".join([str(idx) for idx in idxs]) @@ -279,7 +304,6 @@ def print_reducer(self, reducer: Callable): operator.add: "+=", operator.sub: "-=", operator.mul: "*=", - operator.div: "/=", operator.truediv: "/=", operator.mod: "%=", operator.pow: "**=", @@ -292,34 +316,41 @@ def print_reducer(self, reducer: Callable): } return str_map[reducer] - def print_pointwise_op(self, op: Callable): + def print_pointwise_op_callable(self, op: Callable): str_map = { operator.add: "+", operator.sub: "-", operator.mul: "*", - operator.div: "/", operator.truediv: "/", operator.mod: "%", operator.pow: "**", } return str_map[op] + def print_pointwise_op(self, pointwise_op: PointwiseOp): + if not is_commutative(pointwise_op.op): + return f"({pointwise_op.args[0]} {self.print_pointwise_op_callable(pointwise_op.op)} {pointwise_op.args[1]})" + return f"({f" {self.print_pointwise_op_callable(pointwise_op.op)} ".join(self.print_pointwise_expr(arg) for arg in pointwise_op.args)})" + def print_pointwise_expr(self, pointwise_expr: PointwiseNode): match pointwise_expr: case PointwiseAccess(alias, idxs): return f"{alias}[{self.print_indicies(idxs)}]" - case PointwiseOp(op, args): - return f"{self.print_pointwise_op(op)}({', '.join(self.print_pointwise_expr(arg) for arg in args)})" + case PointwiseOp(_, __): + return self.print_pointwise_op(pointwise_expr) case PointwiseLiteral(val): return str(val) def print_einsum(self, einsum: Einsum): return f"{einsum.output_alias}[{self.print_indicies(einsum.output_fields)}] {self.print_reducer(einsum.reduceOp)} {self.print_pointwise_expr(einsum.pointwise_expr)}" + def print_return_value(self, return_value: Einsum | str): + return return_value if isinstance(return_value, str) else self.print_einsum(return_value) + def print_einsum_plan(self, einsum_plan: EinsumPlan): - if einsum_plan.returnValue is None: + if not einsum_plan.returnValues: return "\n".join([self.print_einsum(einsum) for einsum in einsum_plan.bodies]) - return f"{"\n".join([self.print_einsum(einsum) for einsum in einsum_plan.bodies])}\nreturn {self.print_einsum(einsum_plan.returnValue)}" + return f"{"\n".join([self.print_einsum(einsum) for einsum in einsum_plan.bodies])}\nreturn {", ".join([self.print_return_value(return_value) for return_value in einsum_plan.returnValues])}" def __call__(self, prgm: EinsumPlan): return self.print_einsum_plan(prgm) \ No newline at end of file From 6eabb03296e44f9e8b3b8b5b11ecc15f0115bf49 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Fri, 12 Sep 2025 17:43:13 -0400 Subject: [PATCH 15/26] * Removed unused imports from einsum.py --- src/finchlite/autoschedule/einsum.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py index f14af797..c8ee6b32 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from abc import ABC import operator -import traceback from typing import Callable, Self from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel @@ -9,7 +8,6 @@ from finchlite.symbolic import Term, TermTree from finchlite.autoschedule import optimize from finchlite.algebra import is_commutative, overwrite, init_value -from finchlite.symbolic import Context From 71d7f046f2e80838520711ab35e4222a30f8135c Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 22 Sep 2025 16:21:47 -0400 Subject: [PATCH 16/26] * Added support for printing promote max and promote min reduction operators in einsums --- src/finchlite/autoschedule/einsum.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/finchlite/autoschedule/einsum.py index c8ee6b32..9455b367 100644 --- a/src/finchlite/autoschedule/einsum.py +++ 
b/src/finchlite/autoschedule/einsum.py @@ -7,7 +7,7 @@ from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder, Table from finchlite.symbolic import Term, TermTree from finchlite.autoschedule import optimize -from finchlite.algebra import is_commutative, overwrite, init_value +from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min @dataclass(eq=True, frozen=True) class PointwiseNode(Term, ABC): @@ -311,6 +311,8 @@ def print_reducer(self, reducer: Callable): operator.floordiv: "//=", operator.mod: "%=", operator.pow: "**=", + promote_max: "max=", + promote_min: "min=", } return str_map[reducer] From aa73d5b735ac1d14b034a783fadbef794ac71c1f Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 22 Sep 2025 17:41:37 -0400 Subject: [PATCH 17/26] * Added support for EinsumInterpreter based Einsum scheduler --- src/finchlite/autoschedule/__init__.py | 4 +--- src/finchlite/autoschedule/einsum.py | 25 +++++++++++++++++-------- src/finchlite/autoschedule/optimize.py | 4 ++++ src/finchlite/interface/fuse.py | 12 +++++++++++- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py index ab3eb3b3..64356d98 100644 --- a/src/finchlite/autoschedule/__init__.py +++ b/src/finchlite/autoschedule/__init__.py @@ -40,8 +40,7 @@ ) from .einsum import ( Einsum, - EinsumPlan, - einsum_scheduler + EinsumPlan ) __all__ = [ @@ -50,7 +49,6 @@ "DefaultLogicOptimizer", "Einsum", "EinsumPlan", - "einsum_scheduler", "Field", "Literal", "LogicCompiler", diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 9455b367..768ff4e1 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -5,6 +5,7 @@ from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder, Table +from finchlite.interface.lazy import defer from finchlite.symbolic import Term, TermTree from finchlite.autoschedule import optimize from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min @@ -284,13 +285,9 @@ def __init__(self): def __call__(self, prgm: Plan): parameters = {} definitions = {} - return self.el(prgm, parameters, definitions), parameters, definitions - -def einsum_scheduler(plan: Plan): - optimized_prgm = optimize(plan) - - compiler = EinsumCompiler() - return compiler(optimized_prgm) + einsums = self.el(prgm, parameters, definitions) + + return einsums, parameters, definitions class EinsumPrinterContext: def print_indicies(self, idxs: tuple[Field, ...]): @@ -353,4 +350,16 @@ def print_einsum_plan(self, einsum_plan: EinsumPlan): return f"{"\n".join([self.print_einsum(einsum) for einsum in einsum_plan.bodies])}\nreturn {", ".join([self.print_return_value(return_value) for return_value in einsum_plan.returnValues])}" def __call__(self, prgm: EinsumPlan): - return self.print_einsum_plan(prgm) \ No newline at end of file + return self.print_einsum_plan(prgm) + +class EinsumInterpreter: + def __call__(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): + import numpy as np + for (str, table) in parameters.items(): + print(f"Parameter: {str} = {table}") + + print(einsum_plan) + + # Return the actual numpy array, not a Table object + # The scheduler must return a tuple of actual values + return (np.arange(6, dtype=np.float32).reshape(2, 3),) \ No newline at end of file diff 
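With these entries, a row-maximum einsum would print roughly as (field names illustrative):

    row_max[i] max= A[i, j]

the dense-numpy equivalent of A.max(axis=1).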
--git a/src/finchlite/autoschedule/optimize.py b/src/finchlite/autoschedule/optimize.py index b689a06f..7d945dba 100644 --- a/src/finchlite/autoschedule/optimize.py +++ b/src/finchlite/autoschedule/optimize.py @@ -6,6 +6,7 @@ from typing import TypeVar, overload from finchlite.algebra.algebra import is_annihilator, is_distributive, is_identity +from finchlite.autoschedule.einsum import EinsumCompiler from ..finch_logic import ( Aggregate, @@ -776,6 +777,9 @@ class DefaultLogicOptimizer: def __init__(self, ctx: LogicCompiler): self.ctx = ctx + def __init__(self, ctx: EinsumCompiler): + self.ctx = ctx + def __call__(self, prgm: LogicNode): prgm = optimize(prgm) return self.ctx(prgm) diff --git a/src/finchlite/interface/fuse.py b/src/finchlite/interface/fuse.py index d8cbdd77..3e519066 100644 --- a/src/finchlite/interface/fuse.py +++ b/src/finchlite/interface/fuse.py @@ -51,6 +51,7 @@ """ from ..autoschedule import DefaultLogicOptimizer, LogicCompiler +from ..autoschedule.einsum import EinsumCompiler, EinsumInterpreter from ..finch_logic import Alias, FinchLogicInterpreter, Plan, Produces, Query from ..finch_notation import NotationInterpreter from ..symbolic import gensym @@ -59,13 +60,22 @@ _DEFAULT_SCHEDULER = None -def set_default_scheduler(*, ctx=None, interpret_logic=False): +def set_default_scheduler(*, ctx=None, interpret_logic=False, interpret_einsum=False): global _DEFAULT_SCHEDULER if ctx is not None: _DEFAULT_SCHEDULER = ctx elif interpret_logic: _DEFAULT_SCHEDULER = FinchLogicInterpreter() + elif interpret_einsum: + optimizer = DefaultLogicOptimizer(EinsumCompiler()) + einsum_interpreter = EinsumInterpreter() + + def fn_compile(plan): + einsums, parameters, _ = optimizer(plan) + return einsum_interpreter(einsums, parameters) + + _DEFAULT_SCHEDULER = fn_compile else: optimizer = DefaultLogicOptimizer(LogicCompiler()) ntn_interp = NotationInterpreter() From 5afab01ae480daa79df7c96caedb2abfc8f5eb01 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 22 Sep 2025 22:24:24 -0400 Subject: [PATCH 18/26] * Added sparse tensor type designed specifically for use with einsum autoscheduler --- src/finchlite/autoschedule/einsum.py | 72 +++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 768ff4e1..050a6130 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -8,7 +8,9 @@ from finchlite.interface.lazy import defer from finchlite.symbolic import Term, TermTree from finchlite.autoschedule import optimize -from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min +from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min, TensorFType, register_property +from finchlite.interface.eager import EagerTensor +import numpy as np @dataclass(eq=True, frozen=True) class PointwiseNode(Term, ABC): @@ -352,6 +354,74 @@ def print_einsum_plan(self, einsum_plan: EinsumPlan): def __call__(self, prgm: EinsumPlan): return self.print_einsum_plan(prgm) +class SparseTensorFType(TensorFType): + def __init__(self, shape: tuple, element_type: type, fill_value): + self.shape = shape + self.element_type = element_type + self.fill_value = fill_value + + def __eq__(self, other): + if not isinstance(other, SparseTensorFType): + return False + return self.shape == other.shape and self.element_type == other.element_type and self.fill_value == other.fill_value + + def __hash__(self): + 
return hash((self.shape, self.element_type, self.fill_value)) + + @property + def ndim(self): + return len(self.shape) + + @property + def shape_type(self): + return self.shape + + @property + def fill_value(self): + return self.fill_value + +# currently implemented with COO tensor +class SparseTensor(EagerTensor): + def __init__(self, coords: tuple, data: np.ndarray, shape: tuple, element_type=None, fill_value=0.0): + self.coords = coords + self.data = data + self.shape = shape + self.element_type = element_type + self.fill_value = fill_value + + # converts an eager tensor to a sparse tensor + @classmethod + def from_dense_tensor(cls, dense_tensor: EagerTensor): + coords = np.where(dense_tensor.data != dense_tensor.fill_value) + data = dense_tensor.data[coords] + shape = dense_tensor.shape + element_type = dense_tensor.element_type + fill_value = dense_tensor.fill_value + return cls(coords, data, shape, element_type, fill_value) + + @property + def ftype(self): + return SparseTensorFType(self.shape, self.element_type, self.fill_value) + + @property + def shape(self): + return self.shape + + # calculates the ratio of non-zero elements to the total number of elements + @property + def density(self): + return len(self.data) / np.prod(self.shape) + + def __getitem__(self, idx): + if not isinstance(idx, tuple): + raise ValueError("Index must be a tuple") + + if len(idx) != self.ndim: + raise ValueError(f"Index must have {self.ndim} dimensions") + + #return the first element that matches the index + return self.data[np.all(self.coords == idx, axis=1)][0] + class EinsumInterpreter: def __call__(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): import numpy as np From 9318c447b7fd5613bba8a0f30ac187fcdb2b7eff Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 23 Sep 2025 11:22:45 -0400 Subject: [PATCH 19/26] * Added support for SparseTensor type to automatically use einsum scheduler * Fixed circular import issues --- src/finchlite/algebra/tensor.py | 1 - src/finchlite/autoschedule/__init__.py | 6 +- src/finchlite/autoschedule/einsum.py | 86 ++---------- src/finchlite/autoschedule/optimize.py | 7 +- src/finchlite/autoschedule/sparse_tensor.py | 88 ++++++++++++ src/finchlite/interface/eager.py | 141 +++++++++++--------- src/finchlite/interface/fuse.py | 12 +- 7 files changed, 181 insertions(+), 160 deletions(-) create mode 100644 src/finchlite/autoschedule/sparse_tensor.py diff --git a/src/finchlite/algebra/tensor.py b/src/finchlite/algebra/tensor.py index 2d5a6aa8..20c29c06 100644 --- a/src/finchlite/algebra/tensor.py +++ b/src/finchlite/algebra/tensor.py @@ -6,7 +6,6 @@ from ..algebra import register_property from ..symbolic import FType, FTyped, ftype - class TensorFType(FType, ABC): @property def ndim(self) -> int: diff --git a/src/finchlite/autoschedule/__init__.py b/src/finchlite/autoschedule/__init__.py index 64356d98..9f8f0faa 100644 --- a/src/finchlite/autoschedule/__init__.py +++ b/src/finchlite/autoschedule/__init__.py @@ -40,7 +40,9 @@ ) from .einsum import ( Einsum, - EinsumPlan + EinsumPlan, + EinsumCompiler, + EinsumScheduler ) __all__ = [ @@ -49,6 +51,8 @@ "DefaultLogicOptimizer", "Einsum", "EinsumPlan", + "EinsumCompiler", + "EinsumScheduler", "Field", "Literal", "LogicCompiler", diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 050a6130..b96eecda 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -5,11 +5,8 @@ from finchlite.finch_logic import LogicNode, Field, Plan, 
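The COO layout used here is easiest to see on a tiny example (plain numpy, mirroring what from_dense_tensor does):

    import numpy as np

    dense = np.array([[1.0, 0.0, 0.0],
                      [0.0, 0.0, 2.0]])
    coords = np.where(dense != 0)      # (array([0, 1]), array([0, 2]))
    data = dense[coords]               # array([1., 2.])
    coords_array = np.array(coords).T  # shape (nnz, ndim): [[0, 0], [1, 2]]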
Query, Alias, Literal, Relabel from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder, Table -from finchlite.interface.lazy import defer from finchlite.symbolic import Term, TermTree -from finchlite.autoschedule import optimize -from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min, TensorFType, register_property -from finchlite.interface.eager import EagerTensor +from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min import numpy as np @dataclass(eq=True, frozen=True) @@ -287,9 +284,9 @@ def __init__(self): def __call__(self, prgm: Plan): parameters = {} definitions = {} - einsums = self.el(prgm, parameters, definitions) + einsum_plan = self.el(prgm, parameters, definitions) - return einsums, parameters, definitions + return einsum_plan, parameters, definitions class EinsumPrinterContext: def print_indicies(self, idxs: tuple[Field, ...]): @@ -354,82 +351,19 @@ def print_einsum_plan(self, einsum_plan: EinsumPlan): def __call__(self, prgm: EinsumPlan): return self.print_einsum_plan(prgm) -class SparseTensorFType(TensorFType): - def __init__(self, shape: tuple, element_type: type, fill_value): - self.shape = shape - self.element_type = element_type - self.fill_value = fill_value - - def __eq__(self, other): - if not isinstance(other, SparseTensorFType): - return False - return self.shape == other.shape and self.element_type == other.element_type and self.fill_value == other.fill_value - - def __hash__(self): - return hash((self.shape, self.element_type, self.fill_value)) +class EinsumScheduler: + def __init__(self, ctx: EinsumCompiler): + self.ctx = ctx - @property - def ndim(self): - return len(self.shape) - - @property - def shape_type(self): - return self.shape - - @property - def fill_value(self): - return self.fill_value - -# currently implemented with COO tensor -class SparseTensor(EagerTensor): - def __init__(self, coords: tuple, data: np.ndarray, shape: tuple, element_type=None, fill_value=0.0): - self.coords = coords - self.data = data - self.shape = shape - self.element_type = element_type - self.fill_value = fill_value - - # converts an eager tensor to a sparse tensor - @classmethod - def from_dense_tensor(cls, dense_tensor: EagerTensor): - coords = np.where(dense_tensor.data != dense_tensor.fill_value) - data = dense_tensor.data[coords] - shape = dense_tensor.shape - element_type = dense_tensor.element_type - fill_value = dense_tensor.fill_value - return cls(coords, data, shape, element_type, fill_value) - - @property - def ftype(self): - return SparseTensorFType(self.shape, self.element_type, self.fill_value) - - @property - def shape(self): - return self.shape - - # calculates the ratio of non-zero elements to the total number of elements - @property - def density(self): - return len(self.data) / np.prod(self.shape) - - def __getitem__(self, idx): - if not isinstance(idx, tuple): - raise ValueError("Index must be a tuple") - - if len(idx) != self.ndim: - raise ValueError(f"Index must have {self.ndim} dimensions") - - #return the first element that matches the index - return self.data[np.all(self.coords == idx, axis=1)][0] + def __call__(self, prgm: LogicNode): + einsum_plan, parameters, _ = self.ctx(prgm) + return self.interpret(einsum_plan, parameters) -class EinsumInterpreter: - def __call__(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): + def interpret(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): import numpy as np for (str, table) in 
parameters.items(): print(f"Parameter: {str} = {table}") print(einsum_plan) - # Return the actual numpy array, not a Table object - # The scheduler must return a tuple of actual values return (np.arange(6, dtype=np.float32).reshape(2, 3),) \ No newline at end of file diff --git a/src/finchlite/autoschedule/optimize.py b/src/finchlite/autoschedule/optimize.py index 7d945dba..9fdb5928 100644 --- a/src/finchlite/autoschedule/optimize.py +++ b/src/finchlite/autoschedule/optimize.py @@ -6,7 +6,7 @@ from typing import TypeVar, overload from finchlite.algebra.algebra import is_annihilator, is_distributive, is_identity -from finchlite.autoschedule.einsum import EinsumCompiler +#from finchlite.autoschedule.einsum import EinsumCompiler from ..finch_logic import ( Aggregate, @@ -774,10 +774,7 @@ def rule_1(ex): class DefaultLogicOptimizer: - def __init__(self, ctx: LogicCompiler): - self.ctx = ctx - - def __init__(self, ctx: EinsumCompiler): + def __init__(self, ctx): self.ctx = ctx def __call__(self, prgm: LogicNode): diff --git a/src/finchlite/autoschedule/sparse_tensor.py b/src/finchlite/autoschedule/sparse_tensor.py new file mode 100644 index 00000000..8d86228f --- /dev/null +++ b/src/finchlite/autoschedule/sparse_tensor.py @@ -0,0 +1,88 @@ +from typing import override +from finchlite.algebra import TensorFType +from finchlite.interface.eager import EagerTensor +import numpy as np + +class SparseTensorFType(TensorFType): + def __init__(self, shape: tuple, element_type: type): + self.shape = shape + self._element_type = element_type + + def __eq__(self, other): + if not isinstance(other, SparseTensorFType): + return False + return self.shape == other.shape and self.element_type == other.element_type + + def __hash__(self): + return hash((self.shape, self.element_type)) + + @property + def ndim(self): + return len(self.shape) + + @property + def shape_type(self): + return self.shape + + @property + def element_type(self): + return self._element_type + + @property + def fill_value(self): + return 0 + +# currently implemented with COO tensor +class SparseTensor(EagerTensor): + def __init__(self, data: np.array, coords: np.ndarray, shape: tuple, element_type=np.float64): + self.coords = coords + self.data = data + self._shape = shape + self._element_type = element_type + + # converts an eager tensor to a sparse tensor + @classmethod + def from_dense_tensor(cls, dense_tensor: np.ndarray): + + coords = np.where(dense_tensor != 0) + data = dense_tensor[coords] + shape = dense_tensor.shape + element_type = dense_tensor.dtype.type # Get the type, not the dtype + # Convert coords from tuple of arrays to array of coordinates + coords_array = np.array(coords).T + return cls(data, coords_array, shape, element_type) + + @property + def ftype(self): + return SparseTensorFType(self.shape, self._element_type) + + @property + def shape(self): + return self._shape + + @property + def ndim(self) -> int: + return len(self._shape) + + # calculates the ratio of non-zero elements to the total number of elements + @property + def density(self): + return self.coords.shape[0] / np.prod(self.shape) + + def __getitem__(self, idx: tuple): + if len(idx) != self.ndim: + raise ValueError(f"Index must have {self.ndim} dimensions") + + # coords is a 2D array where each row is a coordinate + mask = np.all(self.coords == idx, axis=1) + matching_indices = np.where(mask)[0] + + if len(matching_indices) > 0: + return self.data[matching_indices[0]] + return 0 + + def to_dense(self) -> np.ndarray: + dense_tensor = np.zeros(self.shape, 
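Rough usage of the new class, exercising only what the file above defines (module path as in this patch):

    import numpy as np
    from finchlite.autoschedule.sparse_tensor import SparseTensor

    s = SparseTensor.from_dense_tensor(np.array([[1.0, 0.0], [0.0, 2.0]]))
    assert s.density == 0.5   # 2 stored values out of 4
    assert s[(1, 1)] == 2.0   # COO lookup
    assert s[(0, 1)] == 0     # missing coordinate falls back to zero
    assert np.array_equal(s.to_dense(), np.array([[1.0, 0.0], [0.0, 2.0]]))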
dtype=self._element_type) + for i in range(self.coords.shape[0]): + dense_tensor[tuple(self.coords[i])] = self.data[i] + return dense_tensor \ No newline at end of file diff --git a/src/finchlite/interface/eager.py b/src/finchlite/interface/eager.py index 66f3a204..5904eea9 100644 --- a/src/finchlite/interface/eager.py +++ b/src/finchlite/interface/eager.py @@ -8,7 +8,6 @@ from .fuse import compute from .overrides import OverrideTensor - class EagerTensor(OverrideTensor, ABC): def override_module(self): return sys.modules[__name__] @@ -208,11 +207,21 @@ def __logical_not__(self): register_property(EagerTensor, "asarray", "__attr__", lambda x: x) +def get_eager_scheduler(*args): + from finchlite.autoschedule.sparse_tensor import SparseTensor + from finchlite.autoschedule import EinsumScheduler, DefaultLogicOptimizer, EinsumCompiler + + for arg in args: + if isinstance(arg, SparseTensor): + return DefaultLogicOptimizer(EinsumScheduler(EinsumCompiler())) + + return None + def permute_dims(arg, /, axis: tuple[int, ...]): if isinstance(arg, lazy.LazyTensor): return lazy.permute_dims(arg, axis=axis) - return compute(lazy.permute_dims(arg, axis=axis)) + return compute(lazy.permute_dims(arg, axis=axis), ctx=get_eager_scheduler(arg)) def expand_dims( @@ -222,7 +231,7 @@ def expand_dims( ): if isinstance(x, lazy.LazyTensor): return lazy.expand_dims(x, axis=axis) - return compute(lazy.expand_dims(x, axis=axis)) + return compute(lazy.expand_dims(x, axis=axis), ctx=get_eager_scheduler(x)) def squeeze( @@ -232,7 +241,7 @@ def squeeze( ): if isinstance(x, lazy.LazyTensor): return lazy.squeeze(x, axis=axis) - return compute(lazy.squeeze(x, axis=axis)) + return compute(lazy.squeeze(x, axis=axis), ctx=get_eager_scheduler(x)) def reduce( @@ -248,7 +257,8 @@ def reduce( if isinstance(x, lazy.LazyTensor): return lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init) return compute( - lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init) + lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init), + ctx=get_eager_scheduler(x) ) @@ -262,7 +272,7 @@ def sum( ): if isinstance(x, lazy.LazyTensor): return lazy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) - return compute(lazy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)) + return compute(lazy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims), ctx=get_eager_scheduler(x)) def prod( @@ -275,49 +285,48 @@ def prod( ): if isinstance(x, lazy.LazyTensor): return lazy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) - return compute(lazy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)) + return compute(lazy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims), ctx=get_eager_scheduler(x)) def elementwise(f: Callable, *args): if builtins.any(isinstance(arg, lazy.LazyTensor) for arg in args): return lazy.elementwise(f, *args) - return compute(lazy.elementwise(f, *args)) - + return compute(lazy.elementwise(f, *args), ctx=get_eager_scheduler(*args)) def add(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.add(x1, x2) - return compute(lazy.add(x1, x2)) + return compute(lazy.add(x1, x2), ctx=get_eager_scheduler(x1, x2)) def subtract(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.subtract(x1, x2) - return compute(lazy.subtract(x1, x2)) + return compute(lazy.subtract(x1, x2), ctx=get_eager_scheduler(x1, x2)) def multiply(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.multiply(x1, x2) - 
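The threading pattern is uniform: every eager wrapper now passes ctx=get_eager_scheduler(...), so a single sparse input reroutes the whole computation. A sketch of the intended call path (illustrative; at this point the einsum scheduler still returns placeholder data):

    import numpy as np
    from finchlite.autoschedule.sparse_tensor import SparseTensor
    from finchlite.interface import eager

    s = SparseTensor.from_dense_tensor(np.eye(3))
    d = np.ones((3, 3))
    # get_eager_scheduler(s, d) -> einsum pipeline (a SparseTensor is present)
    # get_eager_scheduler(d, d) -> None, so compute() keeps the default path
    out = eager.multiply(s, d)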
return compute(lazy.multiply(x1, x2)) + return compute(lazy.multiply(x1, x2), ctx=get_eager_scheduler(x1, x2)) def abs(x): if isinstance(x, lazy.LazyTensor): return lazy.abs(x) - return compute(lazy.abs(x)) + return compute(lazy.abs(x), ctx=get_eager_scheduler(x)) def positive(x): if isinstance(x, lazy.LazyTensor): return lazy.positive(x) - return compute(lazy.positive(x)) + return compute(lazy.positive(x), ctx=get_eager_scheduler(x)) def negative(x): if isinstance(x, lazy.LazyTensor): return lazy.negative(x) - return compute(lazy.negative(x)) + return compute(lazy.negative(x), ctx=get_eager_scheduler(x)) def matmul(x1, x2, /): @@ -330,7 +339,7 @@ def matmul(x1, x2, /): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.matmul(x1, x2) c = lazy.matmul(x1, x2) - return compute(c) + return compute(c, ctx=get_eager_scheduler(x1, x2)) def matrix_transpose(x, /): @@ -339,67 +348,67 @@ def matrix_transpose(x, /): """ if isinstance(x, lazy.LazyTensor): return lazy.matrix_transpose(x) - return compute(lazy.matrix_transpose(x)) + return compute(lazy.matrix_transpose(x), ctx=get_eager_scheduler(x)) def bitwise_inverse(x): if isinstance(x, lazy.LazyTensor): return lazy.bitwise_inverse(x) - return compute(lazy.bitwise_inverse(x)) + return compute(lazy.bitwise_inverse(x), ctx=get_eager_scheduler(x)) def bitwise_and(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_and(x1, x2) - return compute(lazy.bitwise_and(x1, x2)) + return compute(lazy.bitwise_and(x1, x2), ctx=get_eager_scheduler(x1, x2)) def bitwise_left_shift(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_left_shift(x1, x2) - return compute(lazy.bitwise_left_shift(x1, x2)) + return compute(lazy.bitwise_left_shift(x1, x2), ctx=get_eager_scheduler(x1, x2)) def bitwise_or(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_or(x1, x2) - return compute(lazy.bitwise_or(x1, x2)) + return compute(lazy.bitwise_or(x1, x2), ctx=get_eager_scheduler(x1, x2)) def bitwise_right_shift(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_right_shift(x1, x2) - return compute(lazy.bitwise_right_shift(x1, x2)) + return compute(lazy.bitwise_right_shift(x1, x2), ctx=get_eager_scheduler(x1, x2)) def bitwise_xor(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_xor(x1, x2) - return compute(lazy.bitwise_xor(x1, x2)) + return compute(lazy.bitwise_xor(x1, x2), ctx=get_eager_scheduler(x1, x2)) def truediv(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.truediv(x1, x2) - return compute(lazy.truediv(x1, x2)) + return compute(lazy.truediv(x1, x2), ctx=get_eager_scheduler(x1, x2)) def floordiv(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.floordiv(x1, x2) - return compute(lazy.floordiv(x1, x2)) + return compute(lazy.floordiv(x1, x2), ctx=get_eager_scheduler(x1, x2)) def mod(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.mod(x1, x2) - return compute(lazy.mod(x1, x2)) + return compute(lazy.mod(x1, x2), ctx=get_eager_scheduler(x1, x2)) def pow(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.pow(x1, x2) - return compute(lazy.pow(x1, x2)) + return compute(lazy.pow(x1, x2), ctx=get_eager_scheduler(x1, x2)) def tensordot(x1, x2, /, *, 
axes: int | tuple[Sequence[int], Sequence[int]]): @@ -411,7 +420,7 @@ def tensordot(x1, x2, /, *, axes: int | tuple[Sequence[int], Sequence[int]]): """ if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.tensordot(x1, x2, axes=axes) - return compute(lazy.tensordot(x1, x2, axes=axes)) + return compute(lazy.tensordot(x1, x2, axes=axes), ctx=get_eager_scheduler(x1, x2)) def vecdot(x1, x2, /, *, axis=-1): @@ -434,31 +443,31 @@ def vecdot(x1, x2, /, *, axis=-1): """ if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.vecdot(x1, x2, axis=axis) - return compute(lazy.vecdot(x1, x2, axis=axis)) + return compute(lazy.vecdot(x1, x2, axis=axis), ctx=get_eager_scheduler(x1, x2)) def any(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.any(x, axis=axis, keepdims=keepdims) - return compute(lazy.any(x, axis=axis, keepdims=keepdims)) + return compute(lazy.any(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) def all(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.all(x, axis=axis, keepdims=keepdims) - return compute(lazy.all(x, axis=axis, keepdims=keepdims)) + return compute(lazy.all(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) def min(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.min(x, axis=axis, keepdims=keepdims) - return compute(lazy.min(x, axis=axis, keepdims=keepdims)) + return compute(lazy.min(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) def max(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.max(x, axis=axis, keepdims=keepdims) - return compute(lazy.max(x, axis=axis, keepdims=keepdims)) + return compute(lazy.max(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) # manipulation functions: @@ -484,7 +493,7 @@ def broadcast_to(x, /, shape: Sequence[int]): shape = tuple(shape) # Ensure shape is a tuple for consistency if isinstance(x, lazy.LazyTensor): return lazy.broadcast_to(x, shape=shape) - return compute(lazy.broadcast_to(x, shape=shape)) + return compute(lazy.broadcast_to(x, shape=shape), ctx=get_eager_scheduler(x)) def broadcast_arrays(*args): @@ -505,7 +514,7 @@ def broadcast_arrays(*args): if builtins.any(isinstance(arg, lazy.LazyTensor) for arg in args): return lazy.broadcast_arrays(*args) # compute can take in a list of LazyTensors - return compute(lazy.broadcast_arrays(*args)) + return compute(lazy.broadcast_arrays(*args), ctx=get_eager_scheduler(*args)) def concat(arrays: tuple | list, /, *, axis: int | None = 0): @@ -528,7 +537,7 @@ def concat(arrays: tuple | list, /, *, axis: int | None = 0): """ if builtins.any(isinstance(arr, lazy.LazyTensor) for arr in arrays): return lazy.concat(arrays, axis=axis) - return compute(lazy.concat(arrays, axis=axis)) + return compute(lazy.concat(arrays, axis=axis), ctx=get_eager_scheduler(*arrays)) def moveaxis(x, source: int | tuple[int, ...], destination: int | tuple[int, ...], /): @@ -549,7 +558,7 @@ def moveaxis(x, source: int | tuple[int, ...], destination: int | tuple[int, ... 
""" if isinstance(x, lazy.LazyTensor): return lazy.moveaxis(x, source, destination) - return compute(lazy.moveaxis(x, source, destination)) + return compute(lazy.moveaxis(x, source, destination), ctx=get_eager_scheduler(x)) def stack(arrays: Sequence, /, *, axis: int = 0): @@ -570,7 +579,7 @@ def stack(arrays: Sequence, /, *, axis: int = 0): """ if builtins.any(isinstance(arr, lazy.LazyTensor) for arr in arrays): return lazy.stack(arrays, axis=axis) - return compute(lazy.stack(arrays, axis=axis)) + return compute(lazy.stack(arrays, axis=axis), ctx=get_eager_scheduler(*arrays)) def split_dims(x, axis: int, shape: tuple): @@ -603,7 +612,7 @@ def split_dims(x, axis: int, shape: tuple): """ if isinstance(x, lazy.LazyTensor): return lazy.split_dims(x, axis, shape) - return compute(lazy.split_dims(x, axis, shape)) + return compute(lazy.split_dims(x, axis, shape), ctx=get_eager_scheduler(x)) def combine_dims(x, axes: tuple[int, ...]): @@ -638,7 +647,7 @@ def combine_dims(x, axes: tuple[int, ...]): """ if isinstance(x, lazy.LazyTensor): return lazy.combine_dims(x, axes) - return compute(lazy.combine_dims(x, axes)) + return compute(lazy.combine_dims(x, axes), ctx=get_eager_scheduler(x)) def flatten(x): @@ -665,146 +674,146 @@ def flatten(x): """ if isinstance(x, lazy.LazyTensor): return lazy.flatten(x) - return compute(lazy.flatten(x)) + return compute(lazy.flatten(x), ctx=get_eager_scheduler(x)) # trigonometric functions: def sin(x): if isinstance(x, lazy.LazyTensor): return lazy.sin(x) - return compute(lazy.sin(x)) + return compute(lazy.sin(x), ctx=get_eager_scheduler(x)) def sinh(x): if isinstance(x, lazy.LazyTensor): return lazy.sinh(x) - return compute(lazy.sinh(x)) + return compute(lazy.sinh(x), ctx=get_eager_scheduler(x)) def cos(x): if isinstance(x, lazy.LazyTensor): return lazy.cos(x) - return compute(lazy.cos(x)) + return compute(lazy.cos(x), ctx=get_eager_scheduler(x)) def cosh(x): if isinstance(x, lazy.LazyTensor): return lazy.cosh(x) - return compute(lazy.cosh(x)) + return compute(lazy.cosh(x), ctx=get_eager_scheduler(x)) def tan(x): if isinstance(x, lazy.LazyTensor): return lazy.tan(x) - return compute(lazy.tan(x)) + return compute(lazy.tan(x), ctx=get_eager_scheduler(x)) def tanh(x): if isinstance(x, lazy.LazyTensor): return lazy.tanh(x) - return compute(lazy.tanh(x)) + return compute(lazy.tanh(x), ctx=get_eager_scheduler(x)) def asin(x): if isinstance(x, lazy.LazyTensor): return lazy.asin(x) - return compute(lazy.asin(x)) + return compute(lazy.asin(x), ctx=get_eager_scheduler(x)) def asinh(x): if isinstance(x, lazy.LazyTensor): return lazy.asinh(x) - return compute(lazy.asinh(x)) + return compute(lazy.asinh(x), ctx=get_eager_scheduler(x)) def acos(x): if isinstance(x, lazy.LazyTensor): return lazy.acos(x) - return compute(lazy.acos(x)) + return compute(lazy.acos(x), ctx=get_eager_scheduler(x)) def acosh(x): if isinstance(x, lazy.LazyTensor): return lazy.acosh(x) - return compute(lazy.acosh(x)) + return compute(lazy.acosh(x), ctx=get_eager_scheduler(x)) def atan(x): if isinstance(x, lazy.LazyTensor): return lazy.atan(x) - return compute(lazy.atan(x)) + return compute(lazy.atan(x), ctx=get_eager_scheduler(x)) def atanh(x): if isinstance(x, lazy.LazyTensor): return lazy.atanh(x) - return compute(lazy.atanh(x)) + return compute(lazy.atanh(x), ctx=get_eager_scheduler(x)) def atan2(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.atan2(x1, x2) - return compute(lazy.atan2(x1, x2)) + return compute(lazy.atan2(x1, x2), ctx=get_eager_scheduler(x1, 
x2)) def log(x): if isinstance(x, lazy.LazyTensor): return lazy.log(x) - return compute(lazy.log(x)) + return compute(lazy.log(x), ctx=get_eager_scheduler(x)) def log1p(x): if isinstance(x, lazy.LazyTensor): return lazy.log1p(x) - return compute(lazy.log1p(x)) + return compute(lazy.log1p(x), ctx=get_eager_scheduler(x)) def log2(x): if isinstance(x, lazy.LazyTensor): return lazy.log2(x) - return compute(lazy.log2(x)) + return compute(lazy.log2(x), ctx=get_eager_scheduler(x)) def log10(x): if isinstance(x, lazy.LazyTensor): return lazy.log10(x) - return compute(lazy.log10(x)) + return compute(lazy.log10(x), ctx=get_eager_scheduler(x)) def logaddexp(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logaddexp(x1, x2) - return compute(lazy.logaddexp(x1, x2)) + return compute(lazy.logaddexp(x1, x2), ctx=get_eager_scheduler(x1, x2)) def logical_and(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logical_and(x1, x2) - return compute(lazy.logical_and(x1, x2)) + return compute(lazy.logical_and(x1, x2), ctx=get_eager_scheduler(x1, x2)) def logical_or(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logical_or(x1, x2) - return compute(lazy.logical_or(x1, x2)) + return compute(lazy.logical_or(x1, x2), ctx=get_eager_scheduler(x1, x2)) def logical_xor(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logical_xor(x1, x2) - return compute(lazy.logical_xor(x1, x2)) + return compute(lazy.logical_xor(x1, x2), ctx=get_eager_scheduler(x1, x2)) def logical_not(x): if isinstance(x, lazy.LazyTensor): return lazy.logical_not(x) - return compute(lazy.logical_not(x)) + return compute(lazy.logical_not(x), ctx=get_eager_scheduler(x)) def mean(x, /, *, axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.mean(x, axis=axis, keepdims=keepdims) - return compute(lazy.mean(x, axis=axis, keepdims=keepdims)) + return compute(lazy.mean(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) def var( @@ -817,7 +826,7 @@ def var( ): if isinstance(x, lazy.LazyTensor): return lazy.var(x, axis=axis, correction=correction, keepdims=keepdims) - return compute(lazy.var(x, axis=axis, correction=correction, keepdims=keepdims)) + return compute(lazy.var(x, axis=axis, correction=correction, keepdims=keepdims), ctx=get_eager_scheduler(x)) def std( @@ -830,4 +839,4 @@ def std( ): if isinstance(x, lazy.LazyTensor): return lazy.std(x, axis=axis, correction=correction, keepdims=keepdims) - return compute(lazy.std(x, axis=axis, correction=correction, keepdims=keepdims)) + return compute(lazy.std(x, axis=axis, correction=correction, keepdims=keepdims), ctx=get_eager_scheduler(x)) diff --git a/src/finchlite/interface/fuse.py b/src/finchlite/interface/fuse.py index 3e519066..d8cbdd77 100644 --- a/src/finchlite/interface/fuse.py +++ b/src/finchlite/interface/fuse.py @@ -51,7 +51,6 @@ """ from ..autoschedule import DefaultLogicOptimizer, LogicCompiler -from ..autoschedule.einsum import EinsumCompiler, EinsumInterpreter from ..finch_logic import Alias, FinchLogicInterpreter, Plan, Produces, Query from ..finch_notation import NotationInterpreter from ..symbolic import gensym @@ -60,22 +59,13 @@ _DEFAULT_SCHEDULER = None -def set_default_scheduler(*, ctx=None, interpret_logic=False, interpret_einsum=False): +def set_default_scheduler(*, ctx=None, interpret_logic=False): global _DEFAULT_SCHEDULER if ctx is not None: _DEFAULT_SCHEDULER = ctx elif interpret_logic: _DEFAULT_SCHEDULER = FinchLogicInterpreter() - elif interpret_einsum: - optimizer = DefaultLogicOptimizer(EinsumCompiler()) - einsum_interpreter = EinsumInterpreter() - - def fn_compile(plan): - einsums, parameters, _ = optimizer(plan) - return einsum_interpreter(einsums, parameters) - - _DEFAULT_SCHEDULER = fn_compile else: optimizer = DefaultLogicOptimizer(LogicCompiler()) ntn_interp = NotationInterpreter() From 1fd907b26352a9763b39060a442aed36e8b57c96 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Tue, 23 Sep 2025 11:31:58 -0400 Subject: [PATCH 20/26] * Added support for printing sparse tensors --- src/finchlite/autoschedule/sparse_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/finchlite/autoschedule/sparse_tensor.py b/src/finchlite/autoschedule/sparse_tensor.py index 8d86228f..9e4b914b 100644 --- a/src/finchlite/autoschedule/sparse_tensor.py +++ b/src/finchlite/autoschedule/sparse_tensor.py @@ -81,6 +81,9 @@ def __getitem__(self, idx: tuple): return self.data[matching_indices[0]] return 0 + def __str__(self): + return f"SparseTensor(data={self.data}, coords={self.coords}, shape={self.shape}, element_type={self._element_type})" + def to_dense(self) -> np.ndarray: dense_tensor = np.zeros(self.shape, dtype=self._element_type) for i in range(self.coords.shape[0]): From e78c733db9610ca1d31a07d5331f9e4e236018ec Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sun, 28 Sep 2025 22:48:20 -0400 Subject: [PATCH 21/26] * Reverted changes to eager.py --- src/finchlite/autoschedule/einsum.py | 56 ++++++++--- src/finchlite/interface/eager.py | 141 ++++++++++++--------------- 2 files changed, 110 insertions(+), 87 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index b96eecda..215a53e0 100644 --- 
a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -3,6 +3,7 @@ import operator from typing import Callable, Self +from finchlite.algebra.tensor import Tensor from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder, Table from finchlite.symbolic import Term, TermTree @@ -73,6 +74,7 @@ class PointwiseOp(PointwiseNode): op: Callable #the function to apply e.g., operator.add args: tuple[PointwiseNode, ...] # Subtrees + input_fields: tuple[tuple[Field, ...], ...] # Children: The args @classmethod @@ -118,7 +120,7 @@ class Einsum(TermTree): reduceOp: Callable #technically a reduce operation, much akin to the one in aggregate - input_fields: tuple[Field, ...] + input_fields: tuple[Field, ...] #redundant remove later output_fields: tuple[Field, ...] pointwise_expr: PointwiseNode output_alias: str | None @@ -351,19 +353,51 @@ def print_einsum_plan(self, einsum_plan: EinsumPlan): def __call__(self, prgm: EinsumPlan): return self.print_einsum_plan(prgm) +class EinsumCompiler: + def __call__(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): + return self.print(einsum_plan, parameters) + + def print(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): + for (name, table) in parameters.items(): + print(f"Parameter: {name} = {table}") + + print(einsum_plan) + return (np.arange(6, dtype=np.float32).reshape(2, 3),) + + def pointwise_to_numpy(self, pointwise_expr: PointwiseNode, alias_values: dict[str, Tensor]) -> Tensor: + match pointwise_expr: + case PointwiseAccess(alias, idxs): + return alias_values[alias][idxs] + case PointwiseOp(op, args): + # TODO: dispatch on op (operator.add, operator.mul, ...) and evaluate args recursively + raise NotImplementedError(f"Operation {op} not implemented") + case PointwiseLiteral(val): + return val + raise NotImplementedError(f"Pointwise expression {pointwise_expr} not implemented") + + def einsum_to_numpy(self, einsum: Einsum, alias_values: dict[str, Tensor]) -> Tensor: + pass + + def plan_to_numpy(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]) -> tuple[Tensor, ...]: + alias_values = dict() + for (name, table) in parameters.items(): + alias_values[name] = table.tns + + for einsum in einsum_plan.bodies: + alias_values[einsum.output_alias] = self.einsum_to_numpy(einsum, alias_values) + + return tuple( + (alias_values[return_value] if isinstance(return_value, str) else self.einsum_to_numpy(return_value, alias_values)) + for return_value in einsum_plan.returnValues + ) + class EinsumScheduler: def __init__(self, ctx: EinsumCompiler): self.ctx = ctx + self.interpret = EinsumCompiler() def __call__(self, prgm: LogicNode): einsum_plan, parameters, _ = self.ctx(prgm) - return self.interpret(einsum_plan, parameters) - - def interpret(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): - import numpy as np - for (str, table) in parameters.items(): - print(f"Parameter: {str} = {table}") - - print(einsum_plan) - - return (np.arange(6, dtype=np.float32).reshape(2, 3),) \ No newline at end of file + return self.interpret(einsum_plan, parameters) \ No newline at end of file
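einsum_to_numpy is still a stub; for the dense case it will have to realize the pointwise expression over the input fields and then reduce the contracted ones. One possible shape of that computation (a sketch assuming dense numpy operands, not the committed implementation):

    import numpy as np

    # evaluate C[i, j] += (A[i, k] * B[k, j]) by hand:
    A, B = np.random.rand(2, 3), np.random.rand(3, 4)
    pointwise = A[:, :, None] * B[None, :, :]  # broadcast to fields (i, k, j)
    C = pointwise.sum(axis=1)                  # reduce the k axis with add
    assert np.allclose(C, A @ B)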
diff --git a/src/finchlite/interface/eager.py index 9540d3be..720e5949 100644 --- a/src/finchlite/interface/eager.py +++ b/src/finchlite/interface/eager.py @@ -225,21 +225,10 @@ def __ne__(self, other): register_property(EagerTensor, "asarray", "__attr__", lambda x: x) -def get_eager_scheduler(*args): - from finchlite.autoschedule.sparse_tensor import SparseTensor - from finchlite.autoschedule import EinsumScheduler, DefaultLogicOptimizer, EinsumCompiler - - for arg in args: - if isinstance(arg, SparseTensor): - return DefaultLogicOptimizer(EinsumScheduler(EinsumCompiler())) - - return None - - def permute_dims(arg, /, axis: tuple[int, ...]): if isinstance(arg, lazy.LazyTensor): return lazy.permute_dims(arg, axis=axis) - return compute(lazy.permute_dims(arg, axis=axis), ctx=get_eager_scheduler(arg)) + return compute(lazy.permute_dims(arg, axis=axis)) def expand_dims( @@ -249,7 +238,7 @@ def expand_dims( ): if isinstance(x, lazy.LazyTensor): return lazy.expand_dims(x, axis=axis) - return compute(lazy.expand_dims(x, axis=axis), ctx=get_eager_scheduler(x)) + return compute(lazy.expand_dims(x, axis=axis)) def squeeze( @@ -259,7 +248,7 @@ def squeeze( ): if isinstance(x, lazy.LazyTensor): return lazy.squeeze(x, axis=axis) - return compute(lazy.squeeze(x, axis=axis), ctx=get_eager_scheduler(x)) + return compute(lazy.squeeze(x, axis=axis)) def reduce( @@ -276,7 +265,7 @@ def reduce( return lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init) return compute( - lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init), - ctx=get_eager_scheduler(x) + lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init) ) @@ -290,7 +279,7 @@ def sum( ): if isinstance(x, lazy.LazyTensor): return lazy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) - return compute(lazy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)) def prod( @@ -298,53 +287,53 @@ def prod( /, *, axis: int | tuple[int, ...] | None = None, - dtype=None, + dtype=None, keepdims: bool = False, ): if isinstance(x, lazy.LazyTensor): return lazy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) - return compute(lazy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims)) def elementwise(f: Callable, *args): if builtins.any(isinstance(arg, lazy.LazyTensor) for arg in args): return lazy.elementwise(f, *args) - return compute(lazy.elementwise(f, *args), ctx=get_eager_scheduler(*args)) + return compute(lazy.elementwise(f, *args)) def add(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.add(x1, x2) - return compute(lazy.add(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.add(x1, x2)) def subtract(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.subtract(x1, x2) - return compute(lazy.subtract(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.subtract(x1, x2)) def multiply(x1, x2): if isinstance(x1, lazy.LazyTensor) 
or isinstance(x2, lazy.LazyTensor): return lazy.matmul(x1, x2) c = lazy.matmul(x1, x2) - return compute(c, ctx=get_eager_scheduler(x1, x2)) + return compute(c) def matrix_transpose(x, /): @@ -366,67 +355,67 @@ def matrix_transpose(x, /): """ if isinstance(x, lazy.LazyTensor): return lazy.matrix_transpose(x) - return compute(lazy.matrix_transpose(x), ctx=get_eager_scheduler(x)) + return compute(lazy.matrix_transpose(x)) def bitwise_inverse(x): if isinstance(x, lazy.LazyTensor): return lazy.bitwise_inverse(x) - return compute(lazy.bitwise_inverse(x), ctx=get_eager_scheduler(x)) + return compute(lazy.bitwise_inverse(x)) def bitwise_and(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_and(x1, x2) - return compute(lazy.bitwise_and(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.bitwise_and(x1, x2)) def bitwise_left_shift(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_left_shift(x1, x2) - return compute(lazy.bitwise_left_shift(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.bitwise_left_shift(x1, x2)) def bitwise_or(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_or(x1, x2) - return compute(lazy.bitwise_or(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.bitwise_or(x1, x2)) def bitwise_right_shift(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_right_shift(x1, x2) - return compute(lazy.bitwise_right_shift(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.bitwise_right_shift(x1, x2)) def bitwise_xor(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.bitwise_xor(x1, x2) - return compute(lazy.bitwise_xor(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.bitwise_xor(x1, x2)) def truediv(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.truediv(x1, x2) - return compute(lazy.truediv(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.truediv(x1, x2)) def floordiv(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.floordiv(x1, x2) - return compute(lazy.floordiv(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.floordiv(x1, x2)) def mod(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.mod(x1, x2) - return compute(lazy.mod(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.mod(x1, x2)) def pow(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.pow(x1, x2) - return compute(lazy.pow(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.pow(x1, x2)) def tensordot(x1, x2, /, *, axes: int | tuple[Sequence[int], Sequence[int]]): @@ -438,7 +427,7 @@ def tensordot(x1, x2, /, *, axes: int | tuple[Sequence[int], Sequence[int]]): """ if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.tensordot(x1, x2, axes=axes) - return compute(lazy.tensordot(x1, x2, axes=axes), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.tensordot(x1, x2, axes=axes)) def vecdot(x1, x2, /, *, axis=-1): @@ -461,31 +450,31 @@ def vecdot(x1, x2, /, *, axis=-1): """ if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.vecdot(x1, x2, axis=axis) - return compute(lazy.vecdot(x1, x2, axis=axis), ctx=get_eager_scheduler(x1, x2)) + return 
compute(lazy.vecdot(x1, x2, axis=axis)) def any(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.any(x, axis=axis, keepdims=keepdims) - return compute(lazy.any(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.any(x, axis=axis, keepdims=keepdims)) def all(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.all(x, axis=axis, keepdims=keepdims) - return compute(lazy.all(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.all(x, axis=axis, keepdims=keepdims)) def min(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.min(x, axis=axis, keepdims=keepdims) - return compute(lazy.min(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.min(x, axis=axis, keepdims=keepdims)) def max(x, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.max(x, axis=axis, keepdims=keepdims) - return compute(lazy.max(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.max(x, axis=axis, keepdims=keepdims)) # manipulation functions: @@ -511,7 +500,7 @@ def broadcast_to(x, /, shape: Sequence[int]): shape = tuple(shape) # Ensure shape is a tuple for consistency if isinstance(x, lazy.LazyTensor): return lazy.broadcast_to(x, shape=shape) - return compute(lazy.broadcast_to(x, shape=shape), ctx=get_eager_scheduler(x)) + return compute(lazy.broadcast_to(x, shape=shape)) def broadcast_arrays(*args): @@ -532,7 +521,7 @@ def broadcast_arrays(*args): if builtins.any(isinstance(arg, lazy.LazyTensor) for arg in args): return lazy.broadcast_arrays(*args) # compute can take in a list of LazyTensors - return compute(lazy.broadcast_arrays(*args), ctx=get_eager_scheduler(*args)) + return compute(lazy.broadcast_arrays(*args)) def concat(arrays: tuple | list, /, *, axis: int | None = 0): @@ -555,7 +544,7 @@ def concat(arrays: tuple | list, /, *, axis: int | None = 0): """ if builtins.any(isinstance(arr, lazy.LazyTensor) for arr in arrays): return lazy.concat(arrays, axis=axis) - return compute(lazy.concat(arrays, axis=axis), ctx=get_eager_scheduler(*arrays)) + return compute(lazy.concat(arrays, axis=axis)) def moveaxis(x, source: int | tuple[int, ...], destination: int | tuple[int, ...], /): @@ -576,7 +565,7 @@ def moveaxis(x, source: int | tuple[int, ...], destination: int | tuple[int, ... 
""" if isinstance(x, lazy.LazyTensor): return lazy.moveaxis(x, source, destination) - return compute(lazy.moveaxis(x, source, destination), ctx=get_eager_scheduler(x)) + return compute(lazy.moveaxis(x, source, destination)) def stack(arrays: Sequence, /, *, axis: int = 0): @@ -597,7 +586,7 @@ def stack(arrays: Sequence, /, *, axis: int = 0): """ if builtins.any(isinstance(arr, lazy.LazyTensor) for arr in arrays): return lazy.stack(arrays, axis=axis) - return compute(lazy.stack(arrays, axis=axis), ctx=get_eager_scheduler(*arrays)) + return compute(lazy.stack(arrays, axis=axis)) def split_dims(x, axis: int, shape: tuple): @@ -630,7 +619,7 @@ def split_dims(x, axis: int, shape: tuple): """ if isinstance(x, lazy.LazyTensor): return lazy.split_dims(x, axis, shape) - return compute(lazy.split_dims(x, axis, shape), ctx=get_eager_scheduler(x)) + return compute(lazy.split_dims(x, axis, shape)) def combine_dims(x, axes: tuple[int, ...]): @@ -665,7 +654,7 @@ def combine_dims(x, axes: tuple[int, ...]): """ if isinstance(x, lazy.LazyTensor): return lazy.combine_dims(x, axes) - return compute(lazy.combine_dims(x, axes), ctx=get_eager_scheduler(x)) + return compute(lazy.combine_dims(x, axes)) def flatten(x): @@ -692,140 +681,140 @@ def flatten(x): """ if isinstance(x, lazy.LazyTensor): return lazy.flatten(x) - return compute(lazy.flatten(x), ctx=get_eager_scheduler(x)) + return compute(lazy.flatten(x)) # trigonometric functions: def sin(x): if isinstance(x, lazy.LazyTensor): return lazy.sin(x) - return compute(lazy.sin(x), ctx=get_eager_scheduler(x)) + return compute(lazy.sin(x)) def sinh(x): if isinstance(x, lazy.LazyTensor): return lazy.sinh(x) - return compute(lazy.sinh(x), ctx=get_eager_scheduler(x)) + return compute(lazy.sinh(x)) def cos(x): if isinstance(x, lazy.LazyTensor): return lazy.cos(x) - return compute(lazy.cos(x), ctx=get_eager_scheduler(x)) + return compute(lazy.cos(x)) def cosh(x): if isinstance(x, lazy.LazyTensor): return lazy.cosh(x) - return compute(lazy.cosh(x), ctx=get_eager_scheduler(x)) + return compute(lazy.cosh(x)) def tan(x): if isinstance(x, lazy.LazyTensor): return lazy.tan(x) - return compute(lazy.tan(x), ctx=get_eager_scheduler(x)) + return compute(lazy.tan(x)) def tanh(x): if isinstance(x, lazy.LazyTensor): return lazy.tanh(x) - return compute(lazy.tanh(x), ctx=get_eager_scheduler(x)) + return compute(lazy.tanh(x)) def asin(x): if isinstance(x, lazy.LazyTensor): return lazy.asin(x) - return compute(lazy.asin(x), ctx=get_eager_scheduler(x)) + return compute(lazy.asin(x)) def asinh(x): if isinstance(x, lazy.LazyTensor): return lazy.asinh(x) - return compute(lazy.asinh(x), ctx=get_eager_scheduler(x)) + return compute(lazy.asinh(x)) def acos(x): if isinstance(x, lazy.LazyTensor): return lazy.acos(x) - return compute(lazy.acos(x), ctx=get_eager_scheduler(x)) + return compute(lazy.acos(x)) def acosh(x): if isinstance(x, lazy.LazyTensor): return lazy.acosh(x) - return compute(lazy.acosh(x), ctx=get_eager_scheduler(x)) + return compute(lazy.acosh(x)) def atan(x): if isinstance(x, lazy.LazyTensor): return lazy.atan(x) - return compute(lazy.atan(x), ctx=get_eager_scheduler(x)) + return compute(lazy.atan(x)) def atanh(x): if isinstance(x, lazy.LazyTensor): return lazy.atanh(x) - return compute(lazy.atanh(x), ctx=get_eager_scheduler(x)) + return compute(lazy.atanh(x)) def atan2(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.atan2(x1, x2) - return compute(lazy.atan2(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.atan2(x1, 
x2)) def log(x): if isinstance(x, lazy.LazyTensor): return lazy.log(x) - return compute(lazy.log(x), ctx=get_eager_scheduler(x)) + return compute(lazy.log(x)) def log1p(x): if isinstance(x, lazy.LazyTensor): return lazy.log1p(x) - return compute(lazy.log1p(x), ctx=get_eager_scheduler(x)) + return compute(lazy.log1p(x)) def log2(x): if isinstance(x, lazy.LazyTensor): return lazy.log2(x) - return compute(lazy.log2(x), ctx=get_eager_scheduler(x)) + return compute(lazy.log2(x)) def log10(x): if isinstance(x, lazy.LazyTensor): return lazy.log10(x) - return compute(lazy.log10(x), ctx=get_eager_scheduler(x)) + return compute(lazy.log10(x)) def logaddexp(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logaddexp(x1, x2) - return compute(lazy.logaddexp(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.logaddexp(x1, x2)) def logical_and(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logical_and(x1, x2) - return compute(lazy.logical_and(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.logical_and(x1, x2)) def logical_or(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logical_or(x1, x2) - return compute(lazy.logical_or(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.logical_or(x1, x2)) def logical_xor(x1, x2): if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.logical_xor(x1, x2) - return compute(lazy.logical_xor(x1, x2), ctx=get_eager_scheduler(x1, x2)) + return compute(lazy.logical_xor(x1, x2)) def logical_not(x): if isinstance(x, lazy.LazyTensor): return lazy.logical_not(x) - return compute(lazy.logical_not(x), ctx=get_eager_scheduler(x)) + return compute(lazy.logical_not(x)) def less(x1, x2): @@ -867,7 +856,7 @@ def not_equal(x1, x2): def mean(x, /, *, axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False): if isinstance(x, lazy.LazyTensor): return lazy.mean(x, axis=axis, keepdims=keepdims) - return compute(lazy.mean(x, axis=axis, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.mean(x, axis=axis, keepdims=keepdims)) def var( @@ -880,7 +869,7 @@ def var( ): if isinstance(x, lazy.LazyTensor): return lazy.var(x, axis=axis, correction=correction, keepdims=keepdims) - return compute(lazy.var(x, axis=axis, correction=correction, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.var(x, axis=axis, correction=correction, keepdims=keepdims)) def std( @@ -893,4 +882,4 @@ def std( ): if isinstance(x, lazy.LazyTensor): return lazy.std(x, axis=axis, correction=correction, keepdims=keepdims) - return compute(lazy.std(x, axis=axis, correction=correction, keepdims=keepdims), ctx=get_eager_scheduler(x)) + return compute(lazy.std(x, axis=axis, correction=correction, keepdims=keepdims)) From b52e748860787a93351571811870c88e25355434 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Sun, 28 Sep 2025 22:56:29 -0400 Subject: [PATCH 22/26] * Added indirect COO access pointwise operation --- src/finchlite/autoschedule/einsum.py | 58 ++++++++++++++-------------- src/finchlite/interface/eager.py | 4 +- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 215a53e0..ff4b5f8b 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -59,6 +59,33 @@ def from_children(cls, alias: str, idxs: tuple[Field, ...]) -> Self: def children(self): return [self.alias, *self.idxs] + +@dataclass(eq=True, frozen=True) +class PointwiseIndirectCOOAccess(PointwiseNode, TermTree): + """ + PointwiseIndirectCOOAccess + + Tensor access like a[i, j], but for sparse tensors: the access is routed through a coordinate tensor, so in reality it is a[COO_coords[i]]. + + Attributes: + alias: The tensor to access. + coo_coord_alias: The COO coordinate tensor through which the access is made (this is also a tensor). + idx: The single index used to access the COO coordinate tensor. + """ + + alias: str + coo_coord_alias: str + idx: Field #only one index is needed to access the COO coord tensor + # Children: alias, coo_coord_alias, idx + + @classmethod + def from_children(cls, alias: str, coo_coord_alias: str, idx: Field) -> Self: + return cls(alias, coo_coord_alias, idx) + + @property + def children(self): + return [self.alias, self.coo_coord_alias, self.idx]
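For intuition, the access this node encodes is an indirect gather, a[COO_coords[p]]. A minimal NumPy sketch of those semantics follows; the names coords, dense, and vals are illustrative only and not part of finchlite:

import numpy as np

# hypothetical COO coordinate array: one (row, col) pair per stored entry
coords = np.array([[0, 1], [1, 0], [1, 2]])  # shape (nnz, ndim)
dense = np.arange(6.0).reshape(2, 3)         # the tensor being accessed

# a[COO_coords[p]] for each stored position p is an indirect gather:
vals = dense[coords[:, 0], coords[:, 1]]     # array([1., 3., 5.])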
+ """ + + alias: str + coo_coord_alias: str + idx: Field #only one index is needed to access the COO coord tensor + # Children: None (leaf) + + @classmethod + def from_children(cls, alias: str, coo_coord_alias: str, idx: Field) -> Self: + return cls(alias, coo_coord_alias, idx) + + @property + def children(self): + return [self.alias, self.coo_coord_alias, self.idx] + @dataclass(eq=True, frozen=True) class PointwiseOp(PointwiseNode): """ @@ -334,6 +361,8 @@ def print_pointwise_expr(self, pointwise_expr: PointwiseNode): match pointwise_expr: case PointwiseAccess(alias, idxs): return f"{alias}[{self.print_indicies(idxs)}]" + case PointwiseIndirectCOOAccess(alias, coo_coord_alias, idx): + return f"{alias}[{coo_coord_alias}[{self.print_indicies((idx, ))}]]" case PointwiseOp(_, __): return self.print_pointwise_op(pointwise_expr) case PointwiseLiteral(val): @@ -364,35 +393,6 @@ def print(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): print(einsum_plan) return (np.arange(6, dtype=np.float32).reshape(2, 3),) - def pointwise_to_numpy(self, pointwise_expr: PointwiseNode, alias_values: dict[str, Tensor]) -> Tensor: - match pointwise_expr: - case PointwiseAccess(alias, idxs): - return alias_values[alias][idxs] - case PointwiseOp(op, args): - match op: - case operator - case _: - raise NotImplementedError(f"Operation {op} not implemented") - case PointwiseLiteral(val): - return val - raise NotImplementedError(f"Pointwise expression {pointwise_expr} not implemented") - - def einsum_to_numpy(self, einsum: Einsum, alias_values: dict[str, Tensor]) -> Tensor: - pass - - def plan_to_numpy(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]) -> tuple[Tensor, ...]: - alias_values = dict() - for (str, table) in parameters.items(): - alias_values[str] = table.tns - - for einsum in einsum_plan.bodies: - alias_values[einsum.output_alias] = self.einsum_to_numpy(einsum, alias_values) - - return [ - (alias_values[return_value] if isinstance(return_value, str) else self.einsum_to_numpy(return_value, alias_values)) - for return_value in einsum_plan.returnValues - ] - class EinsumScheduler: def __init__(self, ctx: EinsumCompiler): self.ctx = ctx diff --git a/src/finchlite/interface/eager.py b/src/finchlite/interface/eager.py index 720e5949..b972316c 100644 --- a/src/finchlite/interface/eager.py +++ b/src/finchlite/interface/eager.py @@ -264,9 +264,7 @@ def reduce( if isinstance(x, lazy.LazyTensor): return lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init) return compute( - lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init), - ctx=get_eager_scheduler(x) - ) + lazy.reduce(op, x, axis=axis, dtype=dtype, keepdims=keepdims, init=init)) def sum( From 3ff8c38d6ca8aecbb4bbc0e3c351e3d031847021 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 29 Sep 2025 17:16:55 -0400 Subject: [PATCH 23/26] * Fixed issues with printing einsum interpreter --- src/finchlite/autoschedule/einsum.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index ff4b5f8b..6709a2f6 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -101,7 +101,7 @@ class PointwiseOp(PointwiseNode): op: Callable #the function to apply e.g., operator.add args: tuple[PointwiseNode, ...] # Subtrees - input_fields: tuple[tuple[Field, ...], ...] + #input_fields: tuple[tuple[Field, ...], ...] 
# Children: The args @classmethod @@ -382,7 +382,7 @@ def print_einsum_plan(self, einsum_plan: EinsumPlan): def __call__(self, prgm: EinsumPlan): return self.print_einsum_plan(prgm) -class EinsumCompiler: +class EinsumInterpreter: def __call__(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): return self.print(einsum_plan, parameters) @@ -396,7 +396,7 @@ def print(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): class EinsumScheduler: def __init__(self, ctx: EinsumCompiler): self.ctx = ctx - self.interpret = EinsumCompiler() + self.interpret = EinsumInterpreter() def __call__(self, prgm: LogicNode): einsum_plan, parameters, _ = self.ctx(prgm) From 054655f1bbb19541850a142d15b260d74a058ad9 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 29 Sep 2025 17:38:30 -0400 Subject: [PATCH 24/26] * Removed redundant input fields property in einsum IR node --- src/finchlite/autoschedule/einsum.py | 37 ++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 6709a2f6..0fa14478 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -147,24 +147,28 @@ class Einsum(TermTree): reduceOp: Callable #technically a reduce operation, much akin to the one in aggregate - input_fields: tuple[Field, ...] #redundant remove later + #input_fields: tuple[Field, ...] #redundant remove later output_fields: tuple[Field, ...] pointwise_expr: PointwiseNode output_alias: str | None @classmethod def from_children(cls, output_alias: str | None, updateOp: Callable, input_fields: tuple[Field, ...], output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode) -> Self: - return cls(output_alias, updateOp, input_fields, output_fields, pointwise_expr) + #return cls(output_alias, updateOp, input_fields, output_fields, pointwise_expr) + return cls(output_alias, updateOp, output_fields, pointwise_expr) @property def children(self): - return [self.output_alias, self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr] + #return [self.output_alias, self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr] + return [self.output_alias, self.reduceOp, self.output_fields, self.pointwise_expr] def rename(self, new_alias: str): - return Einsum(self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias) + #return Einsum(self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias) + return Einsum(self.reduceOp, self.output_fields, self.pointwise_expr, new_alias) def reorder(self, idxs: tuple[Field, ...]): - return Einsum(self.reduceOp, idxs, self.output_fields, self.pointwise_expr, self.output_alias) + #return Einsum(self.reduceOp, idxs, self.output_fields, self.pointwise_expr, self.output_alias) + return Einsum(self.reduceOp, self.output_fields, self.pointwise_expr, self.output_alias) def __str__(self): ctx = EinsumPrinterContext() @@ -256,14 +260,16 @@ def lower_to_einsum(self, ex: LogicNode, einsums: list[Einsum], parameters: dict case MapJoin(Literal(operation), args): args = [self.lower_to_pointwise(arg, einsums, parameters, definitions) for arg in args] pointwise_expr = self.lower_to_pointwise_op(operation, args) - return Einsum(reduceOp=overwrite, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) + #return Einsum(reduceOp=overwrite, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr, 
output_alias=None) + return Einsum(reduceOp=overwrite, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) case Reorder(arg, idxs): return self.lower_to_einsum(arg, einsums, parameters, definitions).reorder(idxs) case Aggregate(Literal(operation), Literal(init), arg, idxs): if init != init_value(operation, type(init)): raise Exception(f"Init value {init} is not the default value for operation {operation} of type {type(init)}. Non-standard init values are not supported.") pointwise_expr = self.lower_to_pointwise(arg, einsums, parameters, definitions) #return Einsum(operation, arg.fields, ex.fields, pointwise_expr, self.get_next_alias()) + return Einsum(operation, ex.fields, pointwise_expr, self.get_next_alias()) case _: raise Exception(f"Unrecognized logic: {ex}") @@ -310,11 +316,26 @@ class EinsumCompiler: def __init__(self): self.el = EinsumLowerer() + def optimize_einsum(self, einsum_plan: EinsumPlan) -> EinsumPlan: + def optimize_sparse_einsum(einsum: Einsum, extra_ops: list[Einsum]) -> Einsum: + #match einsum: + # case Einsum(reduceOp=add, pointwise) + + return einsum + + optimized_einsums = [] + for einsum in einsum_plan.bodies: + optimized_einsums.append(optimize_sparse_einsum(einsum, optimized_einsums)) + + optimized_return = optimize_sparse_einsum(einsum_plan.returnValues[0], optimized_einsums) + return EinsumPlan(tuple(optimize_sparse_einsum), optimized_return) + def __call__(self, prgm: Plan): parameters = {} definitions = {} einsum_plan = self.el(prgm, parameters, definitions) - + einsum_plan = self.optimize_einsum(einsum_plan) + return einsum_plan, parameters, definitions class EinsumPrinterContext: From 85d348428198798087c15ca3278906964ea91848 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 29 Sep 2025 18:58:44 -0400 Subject: [PATCH 25/26] * Added abstract statement IR node for einsum plans * Removed redundant output fields from Einsum * Added extract COO einsum IR statement --- src/finchlite/autoschedule/einsum.py | 157 +++++++++++++++++++-------- 1 file changed, 114 insertions(+), 43 deletions(-) diff --git a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 0fa14478..9eb9c38c 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,6 +1,8 @@ +from ast import Tuple from dataclasses import dataclass from abc import ABC import operator +from turtle import st from typing import Callable, Self from finchlite.algebra.tensor import Tensor @@ -128,11 +130,35 @@ def __hash__(self): def __eq__(self, other): return isinstance(other, PointwiseLiteral) and self.val == other.val +@dataclass(eq=True, frozen=True) +class EinsumPlanStatement(Term, ABC): + """ + EinsumPlanStatement + + Represents an AST node in the Einsum Plan IR + """ + + @classmethod + def head(cls): + """Returns the head of the node.""" + return cls + + @classmethod + def make_term(cls, head, *children: Term) -> Self: + return head.from_children(*children) + + @classmethod + def from_children(cls, *children: Term) -> Self: + return cls(*children) + + def __str__(self): + ctx = EinsumPrinterContext() + return ctx.print_einsum_plan_statement(self) #einsum and einsum ast not part of logic IR #transform to its own language @dataclass(eq=True, frozen=True) -class Einsum(TermTree): +class Einsum(EinsumPlanStatement, TermTree): """ Einsum reduceOp: Callable #technically a reduce operation, much akin to the one in aggregate - #input_fields: tuple[Field, ...]
#redundant remove later output_fields: tuple[Field, ...] pointwise_expr: PointwiseNode output_alias: str | None + indirect_coo_alias: str | None @classmethod - def from_children(cls, output_alias: str | None, updateOp: Callable, output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode) -> Self: + def from_children(cls, output_alias: str | None, updateOp: Callable, output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode, indirect_coo_alias: str | None) -> Self: #return cls(output_alias, updateOp, input_fields, output_fields, pointwise_expr) - return cls(output_alias, updateOp, output_fields, pointwise_expr) + return cls(output_alias, updateOp, output_fields, pointwise_expr, indirect_coo_alias) @property def children(self): #return [self.output_alias, self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr] - return [self.output_alias, self.reduceOp, self.output_fields, self.pointwise_expr] + return [self.output_alias, self.reduceOp, self.output_fields, self.pointwise_expr, self.indirect_coo_alias] def rename(self, new_alias: str): #return Einsum(self.reduceOp, self.input_fields, self.output_fields, self.pointwise_expr, new_alias) - return Einsum(self.reduceOp, self.output_fields, self.pointwise_expr, new_alias) + return Einsum(self.reduceOp, self.output_fields, self.pointwise_expr, new_alias, self.indirect_coo_alias) def reorder(self, idxs: tuple[Field, ...]): #return Einsum(self.reduceOp, idxs, self.output_fields, self.pointwise_expr, self.output_alias) - return Einsum(self.reduceOp, self.output_fields, self.pointwise_expr, self.output_alias) + return Einsum(self.reduceOp, idxs, self.pointwise_expr, self.output_alias, self.indirect_coo_alias) @dataclass(eq=True, frozen=True) +class ExtractCOO(EinsumPlanStatement): + """ + ExtractCOO + + A plan statement that extracts the COO matrix from a sparse tensor. + """ + alias: str + + @classmethod + def from_children(cls, alias: str) -> Self: + return cls(alias) + + @property + def children(self): + return [self.alias] @dataclass(eq=True, frozen=True) class EinsumPlan(Plan): @@ -182,11 +223,11 @@ class EinsumPlan(Plan): A plan that contains einsum operations. Basically a list of einsum operations. """ - bodies: tuple[Einsum, ...] = () + bodies: tuple[EinsumPlanStatement, ...]
= () returnValues: tuple[Einsum | str] = () @classmethod - def from_children(cls, bodies: tuple[Einsum, ...], returnValue: tuple[Einsum | str]) -> Self: + def from_children(cls, bodies: tuple[EinsumPlanStatement, ...], returnValue: tuple[Einsum | str]) -> Self: return cls(bodies, returnValue) @property @@ -212,14 +253,14 @@ def rename_einsum(self, einsum: Einsum, new_alias: str, definitions: dict[str, E return einsum.rename(new_alias) def compile_plan(self, plan: Plan, parameters: dict[str, Table], definitions: dict[str, Einsum]) -> EinsumPlan: - einsums = [] + einsum_statements: list[EinsumPlanStatement] = [] returnValue = [] for body in plan.bodies: match body: case Plan(_): einsum_plan = self.compile_plan(body, parameters, definitions) - einsums.extend(einsum_plan.bodies) + einsum_statements.extend(einsum_plan.bodies) if einsum_plan.returnValues: if returnValue: @@ -228,22 +269,22 @@ def compile_plan(self, plan: Plan, parameters: dict[str, Table], definitions: di case Query(Alias(name), Table(_, _)): parameters[name] = body.rhs case Query(Alias(name), rhs): - einsums.append(self.rename_einsum(self.lower_to_einsum(rhs, einsums, parameters, definitions), name, definitions)) + einsum_statements.append(self.rename_einsum(self.lower_to_einsum(rhs, einsum_statements, parameters, definitions), name, definitions)) case Produces(args): if returnValue: raise Exception("Cannot invoke return more than once.") for arg in args: - returnValue.append(arg.name if isinstance(arg, Alias) else self.lower_to_einsum(arg, einsums, parameters, definitions)) + returnValue.append(arg.name if isinstance(arg, Alias) else self.lower_to_einsum(arg, einsum_statements, parameters, definitions)) case _: - einsums.append(self.rename_einsum(self.lower_to_einsum(body, einsums, parameters, definitions), self.get_next_alias(), definitions)) + einsum_statements.append(self.rename_einsum(self.lower_to_einsum(body, einsum_statements, parameters, definitions), self.get_next_alias(), definitions)) - return EinsumPlan(tuple(einsums), tuple(returnValue)) + return EinsumPlan(tuple(einsum_statements), tuple(returnValue)) - def lower_to_einsum(self, ex: LogicNode, einsums: list[Einsum], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> Einsum: + def lower_to_einsum(self, ex: LogicNode, einsum_statements: list[EinsumPlanStatement], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> Einsum: match ex: case Plan(_): plan = self.compile_plan(ex, parameters, definitions) - einsums.extend(plan.bodies) + einsum_statements.extend(plan.bodies) if plan.returnValues: raise Exception("Plans with no return value are not statements, but rather are expressions.") @@ -258,18 +299,18 @@ def lower_to_einsum(self, ex: LogicNode, einsums: list[Einsum], parameters: dict return plan.returnValues[0] case MapJoin(Literal(operation), args): - args = [self.lower_to_pointwise(arg, einsums, parameters, definitions) for arg in args] + args = [self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) for arg in args] pointwise_expr = self.lower_to_pointwise_op(operation, args) #return Einsum(reduceOp=overwrite, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) return Einsum(reduceOp=overwrite, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) case Reorder(arg, idxs): - return self.lower_to_einsum(arg, einsums, parameters, definitions).reorder(idxs) + return self.lower_to_einsum(arg, einsum_statements, parameters, definitions).reorder(idxs) case 
Aggregate(Literal(operation), Literal(init), arg, idxs): if init != init_value(operation, type(init)): raise Exception(f"Init value {init} is not the default value for operation {operation} of type {type(init)}. Non-standard init values are not supported.") - pointwise_expr = self.lower_to_pointwise(arg, einsums, parameters, definitions) + pointwise_expr = self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) #return Einsum(operation, arg.fields, ex.fields, pointwise_expr, self.get_next_alias()) - return Einsum(operation, ex.fields, pointwise_expr, self.get_next_alias()) + return Einsum(operation, ex.fields, pointwise_expr, self.get_next_alias(), None) case _: raise Exception(f"Unrecognized logic: {ex}") @@ -294,12 +335,12 @@ def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, return result # lowers nested mapjoin logic IR nodes into a single pointwise expression - def lower_to_pointwise(self, ex: LogicNode, einsums: list[Einsum], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> PointwiseNode: + def lower_to_pointwise(self, ex: LogicNode, einsum_statements: list[EinsumPlanStatement], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> PointwiseNode: match ex: case Reorder(arg, idxs): - return self.lower_to_pointwise(arg, einsums, parameters, definitions) + return self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) case MapJoin(Literal(operation), args): - args = [self.lower_to_pointwise(arg, einsums, parameters, definitions) for arg in args] + args = [self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) for arg in args] return self.lower_to_pointwise_op(operation, args) case Relabel(Alias(name), idxs): # Relabel is really just a glorified pointwise access return PointwiseAccess(alias=name, idxs=idxs) case Literal(value): return PointwiseLiteral(val=value) case Aggregate(_, _, _, _): # an aggregate has to be computed separately as its own einsum aggregate_einsum_alias = self.get_next_alias() - einsums.append(self.rename_einsum(self.lower_to_einsum(ex, einsums, parameters, definitions), aggregate_einsum_alias, definitions)) + einsum_statements.append(self.rename_einsum(self.lower_to_einsum(ex, einsum_statements, parameters, definitions), aggregate_einsum_alias, definitions)) return PointwiseAccess(alias=aggregate_einsum_alias, idxs=tuple(ex.fields)) case _: raise Exception(f"Unrecognized logic: {ex}") class EinsumCompiler: def __init__(self): self.el = EinsumLowerer() - def optimize_einsum(self, einsum_plan: EinsumPlan) -> EinsumPlan: - def optimize_sparse_einsum(einsum: Einsum, extra_ops: list[Einsum]) -> Einsum: - #match einsum: - # case Einsum(reduceOp=add, pointwise) - + def find_sparse_tensors(self, parameters: dict[str, Table])-> dict: # -> dict[str, Tuple[Field, ...]]: getting type errors here + from finchlite.autoschedule.sparse_tensor import SparseTensor + + sparse_tensors = dict() + for alias, value in parameters.items(): + match value: + case Table(tensor, idxs): + if isinstance(tensor, SparseTensor): + sparse_tensors[alias] = idxs + return sparse_tensors + + #getting type errors here if I use dict[str, Tuple[Field, ...]] + def optimize_einsum(self, einsum_plan: EinsumPlan, sparse_aliases: dict) -> EinsumPlan: + def optimize_sparse_einsum(einsum: Einsum, extra_ops: list[EinsumPlanStatement]) -> Einsum: return einsum - optimized_einsums = [] - for einsum in
einsum_plan.bodies: - optimized_einsums.append(optimize_sparse_einsum(einsum, optimized_einsums)) + optimized_einsums: list[EinsumPlanStatement] = [] + for statement in einsum_plan.bodies: + match statement: + case Einsum(_, _, _, _, _): + optimized_einsums.append(optimize_sparse_einsum(statement, optimized_einsums)) + case _: + optimized_einsums.append(statement) - optimized_return = optimize_sparse_einsum(einsum_plan.returnValues[0], optimized_einsums) - return EinsumPlan(tuple(optimize_sparse_einsum), optimized_return) + optimized_returns = [] + for return_value in einsum_plan.returnValues: + match return_value: + case Einsum(_, _, _, _, _): + optimized_returns.append(optimize_sparse_einsum(return_value, optimized_einsums)) + case _: + optimized_returns.append(return_value) + return EinsumPlan(tuple(optimized_einsums), tuple(optimized_returns)) def __call__(self, prgm: Plan): parameters = {} definitions = {} einsum_plan = self.el(prgm, parameters, definitions) - einsum_plan = self.optimize_einsum(einsum_plan) + + sparse_aliases = self.find_sparse_tensors(parameters) + einsum_plan = self.optimize_einsum(einsum_plan, sparse_aliases) return einsum_plan, parameters, definitions @@ -389,18 +451,27 @@ def print_pointwise_expr(self, pointwise_expr: PointwiseNode): case PointwiseLiteral(val): return str(val) - def print_einsum(self, einsum: Einsum): + def print_einsum(self, einsum: Einsum) -> str: + if einsum.indirect_coo_alias: + return f"{einsum.output_alias}[{einsum.indirect_coo_alias}[{self.print_indicies(einsum.output_fields)}]] {self.print_reducer(einsum.reduceOp)} {self.print_pointwise_expr(einsum.pointwise_expr)}" return f"{einsum.output_alias}[{self.print_indicies(einsum.output_fields)}] {self.print_reducer(einsum.reduceOp)} {self.print_pointwise_expr(einsum.pointwise_expr)}" - def print_return_value(self, return_value: Einsum | str): + def print_return_value(self, return_value: Einsum | str) -> str: return return_value if isinstance(return_value, str) else self.print_einsum(return_value) - def print_einsum_plan(self, einsum_plan: EinsumPlan): + def print_einsum_plan_statement(self, einsum_plan_statement: EinsumPlanStatement) -> str: + match einsum_plan_statement: + case Einsum(_, _, _, _, _): + return self.print_einsum(einsum_plan_statement) + case _: + raise Exception(f"Unrecognized einsum plan statement: {einsum_plan_statement}") + + def print_einsum_plan(self, einsum_plan: EinsumPlan) -> str: if not einsum_plan.returnValues: return "\n".join([self.print_einsum(einsum) for einsum in einsum_plan.bodies]) - return f"{"\n".join([self.print_einsum(einsum) for einsum in einsum_plan.bodies])}\nreturn {", ".join([self.print_return_value(return_value) for return_value in einsum_plan.returnValues])}" + return f"{"\n".join([self.print_einsum_plan_statement(statement) for statement in einsum_plan.bodies])}\nreturn {", ".join([self.print_return_value(return_value) for return_value in einsum_plan.returnValues])}" - def __call__(self, prgm: EinsumPlan): + def __call__(self, prgm: EinsumPlan) -> str: return self.print_einsum_plan(prgm) class EinsumInterpreter: From 7d20a027df0af8b920a34ce94c7bec2939de8542 Mon Sep 17 00:00:00 2001 From: TheRealMichaelWang Date: Mon, 29 Sep 2025 21:15:39 -0400 Subject: [PATCH 26/26] * Added support for some indirect einsum operations * Still need to add support for true indirect pointwise access --- src/finchlite/autoschedule/einsum.py | 200 +++++++++++++++++---------- 1 file changed, 126 insertions(+), 74 deletions(-) diff --git 
a/src/finchlite/autoschedule/einsum.py b/src/finchlite/autoschedule/einsum.py index 9eb9c38c..1c1db625 100644 --- a/src/finchlite/autoschedule/einsum.py +++ b/src/finchlite/autoschedule/einsum.py @@ -1,14 +1,12 @@ -from ast import Tuple from dataclasses import dataclass from abc import ABC import operator -from turtle import st from typing import Callable, Self from finchlite.algebra.tensor import Tensor from finchlite.finch_logic import LogicNode, Field, Plan, Query, Alias, Literal, Relabel from finchlite.finch_logic.nodes import Aggregate, MapJoin, Produces, Reorder, Table -from finchlite.symbolic import Term, TermTree +from finchlite.symbolic import Term, TermTree, PostWalk, Rewrite from finchlite.algebra import is_commutative, overwrite, init_value, promote_max, promote_min import numpy as np @@ -59,7 +57,7 @@ def from_children(cls, alias: str, idxs: tuple[Field, ...]) -> Self: @property def children(self): - return [self.alias, *self.idxs] + return [self.alias, self.idxs] @dataclass(eq=True, frozen=True) @@ -89,7 +87,7 @@ def children(self): return [self.alias, self.coo_coord_alias, self.idx] @dataclass(eq=True, frozen=True) -class PointwiseOp(PointwiseNode): +class PointwiseOp(PointwiseNode, TermTree): """ PointwiseOp @@ -107,7 +105,7 @@ class PointwiseOp(PointwiseNode): # Children: The args @classmethod - def from_children(cls, op: Callable, args: tuple[PointwiseNode, ...]) -> Self: + def from_children(cls, op: Callable, *args: tuple[PointwiseNode, ...]) -> Self: return cls(op, args) @property @@ -181,9 +179,9 @@ class Einsum(EinsumPlanStatement, TermTree): indirect_coo_alias: str | None @classmethod - def from_children(cls, output_alias: str | None, updateOp: Callable, output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode, indirect_coo_alias: str | None) -> Self: + def from_children(cls, reduceOp: Callable, output_fields: tuple[Field, ...], pointwise_expr: PointwiseNode, output_alias: str | None, indirect_coo_alias: str | None) -> Self: #return cls(output_alias, updateOp, input_fields, output_fields, pointwise_expr) - return cls(output_alias, updateOp, output_fields, pointwise_expr, indirect_coo_alias) + return cls(reduceOp, output_fields, pointwise_expr, output_alias, indirect_coo_alias) @property def children(self): @@ -199,11 +197,28 @@ def reorder(self, idxs: tuple[Field, ...]): return Einsum(self.reduceOp, idxs, self.pointwise_expr, self.output_alias, self.indirect_coo_alias) @dataclass(eq=True, frozen=True) -class ExtractCOO(EinsumPlanStatement): +class ExtractCOOFromSparse(EinsumPlanStatement): """ - ExtractCOO + ExtractCOOFromSparse - A plan statement that extracts the COO matrix from a sparse tensor. + A plan statement that extracts the coordinate array from a COO sparse tensor. + """ + alias: str + + @classmethod + def from_children(cls, alias: str) -> Self: + return cls(alias) + + @property + def children(self): + return [self.alias] + +@dataclass(eq=True, frozen=True) +class ExtractValuesFromSparse(EinsumPlanStatement): + """ + ExtractValuesFromSparse + + A plan statement that extracts the values array from a COO sparse tensor. """ alias: str @classmethod def from_children(cls, alias: str) -> Self: return cls(alias) @property def children(self): return [self.alias]
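For intuition, these two statements correspond to pulling the coordinate and value arrays out of a COO representation. A minimal sketch using scipy.sparse as a stand-in for finchlite's SparseTensor (illustrative only):

import numpy as np
import scipy.sparse as sp

m = sp.coo_matrix(np.array([[0.0, 1.0], [3.0, 0.0]]))
coords = np.stack([m.row, m.col], axis=1)  # ExtractCOOFromSparse: (nnz, ndim) coordinates
values = m.data                            # ExtractValuesFromSparse: (nnz,) values
# coords -> [[0 1], [1 0]], values -> [1. 3.]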
""" alias: str @@ -227,8 +242,8 @@ class EinsumPlan(Plan): returnValues: tuple[Einsum | str] = () @classmethod - def from_children(cls, bodies: tuple[EinsumPlanStatement, ...], returnValue: tuple[Einsum | str]) -> Self: - return cls(bodies, returnValue) + def from_children(cls, bodies: tuple[EinsumPlanStatement, ...], returnValues: tuple[Einsum | str]) -> Self: + return cls(bodies, returnValues) @property def children(self): @@ -286,7 +301,7 @@ def lower_to_einsum(self, ex: LogicNode, einsum_statements: list[EinsumPlanState plan = self.compile_plan(ex, parameters, definitions) einsum_statements.extend(plan.bodies) - if plan.returnValues: + if not plan.returnValues: raise Exception("Plans with no return value are not statements, but rather are expressions.") if len(plan.returnValues) > 1: @@ -302,7 +317,7 @@ def lower_to_einsum(self, ex: LogicNode, einsum_statements: list[EinsumPlanState args = [self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) for arg in args] pointwise_expr = self.lower_to_pointwise_op(operation, args) #return Einsum(reduceOp=overwrite, input_fields=ex.fields, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) - return Einsum(reduceOp=overwrite, output_fields=ex.fields, pointwise_expr=pointwise_expr, output_alias=None) + return Einsum(reduceOp=overwrite, output_fields=tuple(ex.fields), pointwise_expr=pointwise_expr, output_alias=None, indirect_coo_alias=None) case Reorder(arg, idxs): return self.lower_to_einsum(arg, einsum_statements, parameters, definitions).reorder(idxs) case Aggregate(Literal(operation), Literal(init), arg, idxs): @@ -310,28 +325,29 @@ def lower_to_einsum(self, ex: LogicNode, einsum_statements: list[EinsumPlanState raise Exception(f"Init value {init} is not the default value for operation {operation} of type {type(init)}. 
def lower_to_pointwise_op(self, operation: Callable, args: tuple[PointwiseNode, ...]) -> PointwiseOp: # if operation is commutative, we simply pass all the args to the pointwise op since order of args does not matter if is_commutative(operation): - ret_args = [] # flatten the args - for arg in args: - match arg: - case PointwiseOp(op2, _) if op2 == operation: - ret_args.extend(arg.args) - case _: - ret_args.append(arg) - - return PointwiseOp(operation, ret_args) + def flatten_args(m_args: tuple[PointwiseNode, ...]) -> tuple[PointwiseNode, ...]: + ret_args = [] + for arg in m_args: + match arg: + case PointwiseOp(op2, _) if op2 == operation: + ret_args.extend(flatten_args(arg.args)) + case _: + ret_args.append(arg) + return tuple(ret_args) + return PointwiseOp(operation, flatten_args(args)) # combine args from left to right (i.e a / b / c -> (a / b) / c) assert len(args) > 1 - result = PointwiseOp(operation, args[0], args[1]) + result = PointwiseOp(operation, (args[0], args[1])) for arg in args[2:]: - result = PointwiseOp(operation, result, arg) + result = PointwiseOp(operation, (result, arg)) return result # lowers nested mapjoin logic IR nodes into a single pointwise expression def lower_to_pointwise(self, ex: LogicNode, einsum_statements: list[EinsumPlanStatement], parameters: dict[str, Table], definitions: dict[str, Einsum]) -> PointwiseNode: match ex: case Reorder(arg, idxs): return self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) case MapJoin(Literal(operation), args): args = [self.lower_to_pointwise(arg, einsum_statements, parameters, definitions) for arg in args] return self.lower_to_pointwise_op(operation, args) case Relabel(Alias(name), idxs): # Relabel is really just a glorified pointwise access return PointwiseAccess(alias=name, idxs=idxs) case Literal(value): return PointwiseLiteral(val=value) case Aggregate(_, _, _, _): # an aggregate has to be computed separately as its own einsum aggregate_einsum_alias = self.get_next_alias() einsum_statements.append(self.rename_einsum(self.lower_to_einsum(ex, einsum_statements, parameters, definitions), aggregate_einsum_alias, definitions)) return PointwiseAccess(alias=aggregate_einsum_alias, idxs=tuple(ex.fields)) case _: raise Exception(f"Unrecognized logic: {ex}") -class EinsumCompiler: - def __init__(self): - self.el = EinsumLowerer() - - def find_sparse_tensors(self, parameters: dict[str, Table])-> dict: # -> dict[str, Tuple[Field, ...]]: getting type errors here - from finchlite.autoschedule.sparse_tensor import SparseTensor - - sparse_tensors = dict() - for alias, value in parameters.items(): - match value: - case Table(tensor, idxs): - if isinstance(tensor, SparseTensor): - sparse_tensors[alias] = idxs - return sparse_tensors - - #getting type errors here if I use dict[str, Tuple[Field, ...]] - def optimize_einsum(self, einsum_plan: EinsumPlan, sparse_aliases: dict) -> EinsumPlan: - def optimize_sparse_einsum(einsum: Einsum, extra_ops: list[EinsumPlanStatement]) -> Einsum: - return einsum - - optimized_einsums: list[EinsumPlanStatement] = [] - for statement in einsum_plan.bodies: - match statement: - case Einsum(_, _, _, _, _): - optimized_einsums.append(optimize_sparse_einsum(statement, optimized_einsums)) - case _: - optimized_einsums.append(statement) - - optimized_returns = [] - for return_value in einsum_plan.returnValues: - match return_value: - case Einsum(_, _, _, _, _): - optimized_returns.append(optimize_sparse_einsum(return_value, optimized_einsums)) - case _: - optimized_returns.append(return_value) - return EinsumPlan(tuple(optimized_einsums), tuple(optimized_returns)) - - def __call__(self, prgm: Plan): - parameters = {} - definitions = {} - einsum_plan = self.el(prgm, parameters, definitions) - - sparse_aliases = self.find_sparse_tensors(parameters) - einsum_plan = self.optimize_einsum(einsum_plan, sparse_aliases) - - return einsum_plan, parameters, definitions
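For orientation, the printer below renders each Einsum as out[fields] <reducer>= pointwise, so a matvec-style einsum prints roughly as A1[i] add= A[i, k] * x[k]. A NumPy sketch of what such a printed statement denotes (names are illustrative):

import numpy as np

A = np.arange(6.0).reshape(2, 3)
x = np.array([1.0, 2.0, 3.0])
out = np.zeros(2)  # the init value of operator.add

for i in range(2):
    for k in range(3):
        out[i] += A[i, k] * x[k]  # the reduce op folds pointwise products into out

assert np.allclose(out, A @ x)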
".join([str(idx) for idx in idxs]) @@ -463,6 +432,10 @@ def print_einsum_plan_statement(self, einsum_plan_statement: EinsumPlanStatement match einsum_plan_statement: case Einsum(_, _, _, _, _): return self.print_einsum(einsum_plan_statement) + case ExtractCOOFromSparse(alias): + return f"{alias}_coo = ExtractCOO({alias})" + case ExtractValuesFromSparse(alias): + return f"{alias}_values = ExtractValues({alias})" case _: raise Exception(f"Unrecognized einsum plan statement: {einsum_plan_statement}") @@ -484,6 +457,85 @@ def print(self, einsum_plan: EinsumPlan, parameters: dict[str, Table]): print(einsum_plan) return (np.arange(6, dtype=np.float32).reshape(2, 3),) + + +class EinsumCompiler: + def __init__(self): + self.el = EinsumLowerer() + + def find_sparse_tensors(self, parameters: dict[str, Table])-> tuple[dict, dict]: # -> dict[str, Tuple[Field, ...]]: getting type errors here + from finchlite.autoschedule.sparse_tensor import SparseTensor + + special_field_number = 1 + sparse_tensors = dict() + sparse_fields = dict() + + for alias, value in parameters.items(): + match value: + case Table(Literal(tensor), idxs): + if isinstance(tensor, SparseTensor): + if idxs in sparse_fields: + sparse_tensors[alias] = sparse_fields[idxs][1] + else: + sparse_tensors[alias] = Field(f"sf_{special_field_number}") + sparse_fields[idxs] = (alias, Field(f"sf_{special_field_number}")) + special_field_number += 1 + return sparse_tensors, sparse_fields + + #getting type errors here if I use dict[str, Tuple[Field, ...]] + def optimize_einsum(self, einsum_plan: EinsumPlan, sparse_aliases: dict, sparse_fields: dict) -> EinsumPlan: + def optimize_sparse_einsum(einsum: Einsum, extra_ops: list[EinsumPlanStatement]) -> Einsum: + extracted_value_alias = set() + extracted_coo_alias = set() + + def optimize_pointwise_access(node: PointwiseNode) -> PointwiseNode: + match node: + case PointwiseAccess(alias, idxs): + if idxs in sparse_fields: + if alias in sparse_aliases: + if alias not in extracted_value_alias: + extracted_value_alias.add(alias) + extra_ops.append(ExtractValuesFromSparse(alias)) + return PointwiseAccess(f"{alias}_values", (sparse_aliases[alias],)) + else: + if alias not in extracted_coo_alias: + extracted_coo_alias.add(alias) + extra_ops.append(ExtractCOOFromSparse(alias)) + return PointwiseIndirectCOOAccess(alias, f"{sparse_fields[idxs][0]}_coo", sparse_fields[idxs][1]) + return node + + new_pointwise_expr = Rewrite(PostWalk(optimize_pointwise_access))(einsum.pointwise_expr) + + if einsum.output_fields in sparse_fields: + return Einsum(einsum.reduceOp, (sparse_fields[einsum.output_fields][1],), new_pointwise_expr, einsum.output_alias, f"{sparse_fields[einsum.output_fields][0]}_coo") + return Einsum(einsum.reduceOp, einsum.output_fields, new_pointwise_expr, einsum.output_alias, None) + + optimized_einsums: list[EinsumPlanStatement] = [] + for statement in einsum_plan.bodies: + match statement: + case Einsum(_, _, _, _, _): + optimized_einsums.append(optimize_sparse_einsum(statement, optimized_einsums)) + case _: + optimized_einsums.append(statement) + + optimized_returns = [] + for return_value in einsum_plan.returnValues: + match return_value: + case Einsum(_, _, _, _, _): + optimized_returns.append(optimize_sparse_einsum(return_value, optimized_einsums)) + case _: + optimized_returns.append(return_value) + return EinsumPlan(tuple(optimized_einsums), tuple(optimized_returns)) + + def __call__(self, prgm: Plan): + parameters = {} + definitions = {} + einsum_plan = self.el(prgm, parameters, definitions) + + 
sparse_aliases, sparse_fields = self.find_sparse_tensors(parameters) + einsum_plan = self.optimize_einsum(einsum_plan, sparse_aliases, sparse_fields) + + return einsum_plan, parameters, definitions class EinsumScheduler: def __init__(self, ctx: EinsumCompiler):