diff --git a/src/finchlite/algebra/__init__.py b/src/finchlite/algebra/__init__.py index f15d5640..292f9863 100644 --- a/src/finchlite/algebra/__init__.py +++ b/src/finchlite/algebra/__init__.py @@ -15,12 +15,14 @@ ) from .operator import ( InitWrite, + cansplitpush, conjugate, first_arg, identity, overwrite, promote_max, promote_min, + repeat_operator, ) from .tensor import ( Tensor, @@ -37,6 +39,7 @@ "Tensor", "TensorFType", "TensorPlaceholder", + "cansplitpush", "conjugate", "element_type", "fill_value", @@ -56,6 +59,7 @@ "promote_type", "query_property", "register_property", + "repeat_operator", "return_type", "shape_type", ] diff --git a/src/finchlite/algebra/algebra.py b/src/finchlite/algebra/algebra.py index 30d5c39e..2093ac4c 100644 --- a/src/finchlite/algebra/algebra.py +++ b/src/finchlite/algebra/algebra.py @@ -434,6 +434,9 @@ def is_distributive(op, other_op): (np.logical_or, lambda op, other_op: other_op == np.logical_and), (operator.pow, lambda op, other_op: False), (operator.truediv, lambda op, other_op: False), + (operator.add, lambda op, other_op: False), + (max, lambda op, other_op: False), + (min, lambda op, other_op: False), ]: register_property(fn, "__call__", "is_distributive", func) diff --git a/src/finchlite/algebra/operator.py b/src/finchlite/algebra/operator.py index 2384e324..ac32d233 100644 --- a/src/finchlite/algebra/operator.py +++ b/src/finchlite/algebra/operator.py @@ -1,4 +1,8 @@ +import math +import operator + from . import algebra +from .algebra import is_associative, is_commutative, is_idempotent def and_test(a, b): @@ -147,3 +151,73 @@ def identity(x): "return_type", lambda op, x: x, ) + + +def repeat_operator(x): + """ + If there exists an operator g such that + f(x, x, ..., x) (n times) is equal to g(x, n), + then return g. + """ + if not callable(x): + raise TypeError("Can't check repeat operator of non-callable objects!") + + if is_idempotent(x): + return None + + if x is operator.add: + return operator.mul + + if x is operator.mul: + return math.exp + + return None + + +for fn in [ + operator.and_, + operator.or_, + min, + max, +]: + algebra.register_property( + fn, + "__call__", + "repeat_operator", + lambda op: None, + ) + +algebra.register_property( + operator.add, + "__call__", + "repeat_operator", + lambda op: operator.mul, +) + +algebra.register_property( + operator.mul, + "__call__", + "repeat_operator", + lambda op: math.exp, +) + + +def cansplitpush(x, y): + """ + Return True if a reduction with operator `x` can be 'split-pushed' through + a pointwise operator `y`. + + We allow split-push when: + - x has a known repeat operator (repeat_operator(x) is not None), + - x and y are the same operator, + - and x is both commutative and associative. + """ + if not callable(x) or not callable(y): + raise TypeError("Can't check splitpush of non-callable operators!") + + return ( + repeat_operator(x) is not None + and x == y + and is_commutative(x) + and is_associative(x) + ) diff --git a/src/finchlite/galley/LogicalOptimizer/__init__.py b/src/finchlite/galley/LogicalOptimizer/__init__.py index 8c3598b6..3722fa71 100644 --- a/src/finchlite/galley/LogicalOptimizer/__init__.py +++ b/src/finchlite/galley/LogicalOptimizer/__init__.py @@ -1,13 +1,17 @@ from .annotated_query import ( AnnotatedQuery, + find_lowest_roots, get_idx_connected_components, get_reducible_idxs, + replace_and_remove_nodes, ) from .logic_to_stats import insert_statistics __all__ = [ "AnnotatedQuery", + "find_lowest_roots", "get_idx_connected_components", "get_reducible_idxs", "insert_statistics", + "replace_and_remove_nodes", ] diff --git a/src/finchlite/galley/LogicalOptimizer/annotated_query.py b/src/finchlite/galley/LogicalOptimizer/annotated_query.py index 91acaed3..cd5c7379 100644 --- a/src/finchlite/galley/LogicalOptimizer/annotated_query.py +++ b/src/finchlite/galley/LogicalOptimizer/annotated_query.py @@ -1,28 +1,41 @@ from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Collection, Iterable, Mapping from dataclasses import dataclass from typing import Any +from finchlite.algebra import ( + cansplitpush, + is_distributive, +) from finchlite.finch_logic import ( + Aggregate, Alias, + Field, + Literal, + LogicExpression, LogicNode, + MapJoin, + Plan, + Query, + Table, ) +from finchlite.galley.TensorStats import TensorStats @dataclass class AnnotatedQuery: - ST: type + ST: type[TensorStats] output_name: Alias | None - reduce_idxs: list[str] - point_expr: "LogicNode" - idx_lowest_root: OrderedDict[str, LogicNode] - idx_op: OrderedDict[str, Any] - idx_init: OrderedDict[str, Any] - parent_idxs: OrderedDict[str, list[str]] - original_idx: OrderedDict[str, str] - connected_components: list[list[str]] - connected_idxs: OrderedDict[str, set[str]] - output_order: list[str] | None = None + reduce_idxs: list[Field] + point_expr: LogicNode + idx_lowest_root: OrderedDict[Field, LogicExpression] + idx_op: OrderedDict[Field, Any] + idx_init: OrderedDict[Field, Any] + parent_idxs: OrderedDict[Field, list[Field]] + original_idx: OrderedDict[Field, Field] + connected_components: list[list[Field]] + connected_idxs: OrderedDict[Field, set[Field]] + output_order: list[Field] | None = None output_format: list[Any] | None = None @@ -30,24 +43,25 @@ def copy_aq(aq: AnnotatedQuery) -> AnnotatedQuery: """ Make a structured copy of an AnnotatedQuery. """ - return AnnotatedQuery( - ST=aq.ST, - output_name=aq.output_name, - reduce_idxs=list(aq.reduce_idxs), - point_expr=aq.point_expr, - idx_lowest_root=aq.idx_lowest_root.copy(), - idx_op=OrderedDict(aq.idx_op.items()), - idx_init=OrderedDict(aq.idx_init.items()), - parent_idxs=OrderedDict((m, list(n)) for m, n in aq.parent_idxs.items()), - original_idx=OrderedDict(aq.original_idx.items()), - connected_components=[list(n) for n in aq.connected_components], - connected_idxs=OrderedDict((m, set(n)) for m, n in aq.connected_idxs.items()), - output_order=None if aq.output_order is None else list(aq.output_order), - output_format=None if aq.output_format is None else list(aq.output_format), - ) + new = object.__new__(AnnotatedQuery) + new.ST = aq.ST + new.output_name = aq.output_name + new.point_expr = aq.point_expr + new.reduce_idxs = list(aq.reduce_idxs) + new.idx_lowest_root = OrderedDict(aq.idx_lowest_root.items()) + new.idx_op = OrderedDict(aq.idx_op.items()) + new.idx_init = OrderedDict(aq.idx_init.items()) + new.parent_idxs = OrderedDict((m, list(n)) for m, n in aq.parent_idxs.items()) + new.original_idx = OrderedDict(aq.original_idx.items()) + new.connected_components = [list(n) for n in aq.connected_components] + new.connected_idxs = OrderedDict((m, set(n)) for m, n in aq.connected_idxs.items()) + new.output_order = None if aq.output_order is None else list(aq.output_order) + new.output_format = None if aq.output_format is None else list(aq.output_format) + return new -def get_reducible_idxs(aq: AnnotatedQuery) -> list[str]: + +def get_reducible_idxs(aq: AnnotatedQuery) -> list[Field]: """ Indices eligible to be reduced immediately (no parents). @@ -58,44 +72,45 @@ def get_reducible_idxs(aq: AnnotatedQuery) -> list[str]: Returns ------- - list[str] - Indices in `aq.reduce_idxs` with zero parents. + list[Field] + Field objects in `aq.reduce_idxs` with zero parents. """ return [idx for idx in aq.reduce_idxs if len(aq.parent_idxs.get(idx, [])) == 0] def get_idx_connected_components( - parent_idxs: dict[str, Iterable[str]], - connected_idxs: dict[str, Iterable[str]], -) -> list[list[str]]: + parent_idxs: Mapping[Field, Iterable[Field]], + connected_idxs: Mapping[Field, Iterable[Field]], +) -> list[list[Field]]: """ - Compute connected components of indices and order those components by - parent/child constraints. + Compute connected components of indices (Field objects) and order those + components by parent/child constraints. Parameters ---------- - parent_idxs : Dict[str, Iterable[str]] - Mapping from an index to the set/iterable of its parent indices. - connected_idxs : Dict[str, Iterable[str]] - Mapping from an index to the set/iterable of indices considered + parent_idxs : Dict[Field, Iterable[Field]] + Mapping from an index to the iterable of its parent indices. + connected_idxs : Dict[Field, Iterable[Field]] + Mapping from an index to the iterable of indices considered "connected" to it (undirected neighbors). Only connections between non-parent pairs are used to form components. Returns ------- - List[List[str]] - A list of components, each a list of index names. Components are + List[List[Field]] + A list of components, each a list of Field objects. Components are ordered so that any component containing a parent appears before any component containing its child. """ - parent_map = {k: set(v) for k, v in parent_idxs.items()} - conn_map: OrderedDict[str, set[str]] = OrderedDict( + parent_map: dict[Field, set[Field]] = {k: set(v) for k, v in parent_idxs.items()} + conn_map: OrderedDict[Field, set[Field]] = OrderedDict( (k, set(v)) for k, v in connected_idxs.items() ) - component_ids: OrderedDict[str, int] = OrderedDict( + component_ids: OrderedDict[Field, int] = OrderedDict( (x, i) for i, x in enumerate(conn_map.keys()) ) + finished = False while not finished: finished = True @@ -111,12 +126,12 @@ def get_idx_connected_components( component_ids[idx1] = min(component_ids[idx2], component_ids[idx1]) unique_ids = list(OrderedDict.fromkeys(component_ids[idx] for idx in conn_map)) - components: list[list[str]] = [] + components: list[list[Field]] = [] for id in unique_ids: members = [idx for idx in conn_map if component_ids[idx] == id] components.append(members) - component_order: OrderedDict[tuple, int] = OrderedDict( + component_order: OrderedDict[tuple[Field, ...], int] = OrderedDict( (tuple(c), i) for i, c in enumerate(components) ) @@ -153,3 +168,109 @@ def get_idx_connected_components( components.sort(key=lambda c: component_order[tuple(c)]) return components + + +def replace_and_remove_nodes( + expr: LogicNode, + node_to_replace: LogicExpression, + new_node: LogicExpression, + nodes_to_remove: Collection[LogicExpression], +) -> LogicNode: + """ + Replace and/or remove arguments of a pointwise MapJoin expression. + + Parameters + ---------- + expr : LogicNode + The expression to transform. Typically a `MapJoin` in a pointwise + subexpression. + node_to_replace : LogicNode + The node to replace when it appears as an argument to `expr`, or as + `expr` itself. + new_node : LogicNode + The node that replaces `node_to_replace` wherever it is found. + nodes_to_remove : Collection[LogicNode] + A collection of nodes that, if present as arguments to a `MapJoin`, + should be removed from its argument list. + + Returns + ------- + LogicNode + A new `MapJoin` node with updated arguments if `expr` is a `MapJoin`, + `new_node` if `expr == node_to_replace`, or the original `expr` + otherwise. + """ + if expr == node_to_replace: + return new_node + + if isinstance(expr, (Plan, Query, Aggregate)): + raise ValueError( + f"There should be no {type(expr).__name__} nodes in a pointwise expression." + ) + + if isinstance(expr, MapJoin): + nodes_to_remove = set(nodes_to_remove) + new_args: list[LogicExpression] = [] + + for arg in expr.args: + if arg in nodes_to_remove: + continue + if arg == node_to_replace: + arg = new_node + new_args.append(arg) + + return MapJoin(expr.op, tuple(new_args)) + return expr + + +def find_lowest_roots( + op: Literal, idx: Field, root: LogicExpression +) -> list[LogicExpression]: + """ + Compute the lowest MapJoin / leaf nodes that a reduction over `idx` can be + safely pushed down to in a logical expression. + + Parameters + ---------- + op : Literal + The reduction operator node (e.g., Literal(operator.add)) + that we are trying to push down. + idx : Field + The index (dimension) being reduced over. + root : LogicExpression + The root logical expression under which we search for the lowest + pushdown positions for the reduction. + + Returns + ------- + list[LogicExpression] + ` A list of expression nodes representing the lowest positions in + the expression tree where the reduction over `idx` with operator + `op` can be safely pushed down. + """ + + if isinstance(root, MapJoin): + if not isinstance(root.op, Literal): + raise TypeError( + f"Expected MapJoin.op to be a Literal, got {type(root.op).__name__}" + ) + args_with = [arg for arg in root.args if idx in arg.fields] + args_without = [arg for arg in root.args if idx not in arg.fields] + + if len(args_with) == 1 and is_distributive(root.op.val, op.val): + return find_lowest_roots(op, idx, args_with[0]) + + if cansplitpush(op.val, root.op.val): + roots_without: list[LogicExpression] = list(args_without) + roots_with: list[LogicExpression] = [] + for arg in args_with: + roots_with.extend(find_lowest_roots(op, idx, arg)) + return roots_without + roots_with + return [root] + + if isinstance(root, (Alias, Table)): + return [root] + + raise ValueError( + f"There shouldn't be nodes of type {type(root).__name__} during root pushdown." + ) diff --git a/tests/test_algebra.py b/tests/test_algebra.py index ee9c7b5d..2a609f97 100644 --- a/tests/test_algebra.py +++ b/tests/test_algebra.py @@ -4,12 +4,14 @@ import numpy as np from finchlite.algebra import ( + cansplitpush, init_value, is_annihilator, is_associative, is_distributive, is_idempotent, is_identity, + repeat_operator, ) @@ -67,3 +69,9 @@ def test_algebra_selected(): assert is_idempotent(operator.xor) is False assert is_idempotent(np.logical_xor) is False assert is_idempotent(np.logaddexp) is False + assert repeat_operator(operator.add) is operator.mul + assert repeat_operator(operator.mul) is math.exp + assert repeat_operator(operator.and_) is None + assert cansplitpush(operator.add, operator.add) is True + assert cansplitpush(operator.add, operator.mul) is False + assert cansplitpush(operator.and_, operator.and_) is False diff --git a/tests/test_galley.py b/tests/test_galley.py index 7bf23ca3..464a4a87 100644 --- a/tests/test_galley.py +++ b/tests/test_galley.py @@ -14,9 +14,11 @@ ) from finchlite.galley.LogicalOptimizer import ( AnnotatedQuery, + find_lowest_roots, get_idx_connected_components, get_reducible_idxs, insert_statistics, + replace_and_remove_nodes, ) from finchlite.galley.TensorStats import DC, DCStats, DenseStats, TensorDef @@ -1311,23 +1313,35 @@ def test_varied_reduce_DC_card(dims, dcs, reduce_indices, expected_nnz): ], ) def test_get_reducible_idxs(reduce_idxs, parent_idxs, expected): - aq = AnnotatedQuery( - ST=object, - output_name=None, - reduce_idxs=list(reduce_idxs), - point_expr=None, - idx_lowest_root=OrderedDict(), - idx_op=OrderedDict(), - idx_init=OrderedDict(), - parent_idxs=OrderedDict((k, list(v)) for k, v in parent_idxs.items()), - original_idx=OrderedDict(), - connected_components=[], - connected_idxs=OrderedDict(), - output_order=None, - output_format=None, - ) - - assert get_reducible_idxs(aq) == expected + names = set(reduce_idxs) + names.update(parent_idxs.keys()) + for i in parent_idxs.values(): + names.update(i) + + fields: dict[str, Field] = {x: Field(x) for x in names} + reduce_fields: list[Field] = [fields[name] for name in reduce_idxs] + parent_fields: OrderedDict[Field, list[Field]] = OrderedDict( + (fields[key], [fields[p] for p in parents]) + for key, parents in parent_idxs.items() + ) + + aq = object.__new__(AnnotatedQuery) + aq.ST = object + aq.output_name = None + aq.reduce_idxs = reduce_fields + aq.point_expr = None + aq.idx_lowest_root = OrderedDict() + aq.idx_op = OrderedDict() + aq.idx_init = OrderedDict() + aq.parent_idxs = parent_fields + aq.original_idx = OrderedDict() + aq.connected_components = [] + aq.connected_idxs = OrderedDict() + aq.output_order = None + aq.output_format = None + + result = [field.name for field in get_reducible_idxs(aq)] + assert result == expected @pytest.mark.parametrize( @@ -1372,5 +1386,154 @@ def test_get_reducible_idxs(reduce_idxs, parent_idxs, expected): ], ) def test_get_idx_connected_components(parent_idxs, connected_idxs, expected): - out = get_idx_connected_components(parent_idxs, connected_idxs) - assert out == expected + names: set[str] = set(parent_idxs.keys()) | set(connected_idxs.keys()) + for i in parent_idxs.values(): + names.update(i) + for j in connected_idxs.values(): + names.update(j) + + name = {x: Field(x) for x in names} + + parent_field_idxs: dict[Field, list[Field]] = { + name[k]: [name[p] for p in v] for k, v in parent_idxs.items() + } + connected_field_idxs: dict[Field, list[Field]] = { + name[k]: [name[n] for n in v] for k, v in connected_idxs.items() + } + + components = get_idx_connected_components(parent_field_idxs, connected_field_idxs) + result = [[field.name for field in comp] for comp in components] + + assert result == expected + + +@pytest.mark.parametrize( + "arg_names,node_to_replace,nodes_to_remove,expected_names", + [ + (["a", "b", "c"], "b", [], ["a", "a", "c"]), + (["a", "b", "c"], "b", ["c"], ["a", "a"]), + (["a", "b", "c"], "c", ["c"], ["a", "b"]), + (["a", "b", "c"], "b", ["b"], ["a", "c"]), + (["a", "b", "c"], "c", [], ["a", "b", "a"]), + ], +) +def test_replace_and_remove_nodes( + arg_names, + node_to_replace, + nodes_to_remove, + expected_names, +): + args = [Table(Literal(name), (Field(name),)) for name in arg_names] + node_to_replace_node = next( + tbl for tbl in args if tbl.idxs[0].name == node_to_replace + ) + new_node = args[0] + + nodes_to_remove_nodes = {tbl for tbl in args if tbl.idxs[0].name in nodes_to_remove} + + expr_node = MapJoin(Literal("op"), tuple(args)) + + out = replace_and_remove_nodes( + expr=expr_node, + node_to_replace=node_to_replace_node, + new_node=new_node, + nodes_to_remove=nodes_to_remove_nodes, + ) + result = [tbl.idxs[0].name for tbl in out.args] + + assert result == expected_names + + +@pytest.mark.parametrize( + "root, idx_name, expected", + [ + # Distributive case: + # root = MapJoin(mul, [A(i), B(j)]), reduce over j → [B] + ( + MapJoin( + Literal(op.mul), + ( + Table(Literal("A"), (Field("i"),)), + Table(Literal("B"), (Field("j"),)), + ), + ), + "j", + ["B"], + ), + # Split-push case: + # root = MapJoin(add, [A(i), B(i), C(j)]), reduce over i → [C, A, B] + ( + MapJoin( + Literal(op.add), + ( + Table(Literal("A"), (Field("i"),)), + Table(Literal("B"), (Field("i"),)), + Table(Literal("C"), (Field("j"),)), + ), + ), + "i", + ["C", "A", "B"], + ), + # Leaf case: + # root = Table(A(i)), reduce over i → [A] + ( + Table(Literal("A"), (Field("i"),)), + "i", + ["A"], + ), + # Nested case: + # root = MapJoin(mul, [A(i,j), B(i)]), reduce over i → [A] + ( + MapJoin( + Literal(op.mul), + ( + Table(Literal("A"), (Field("i"), Field("j"))), + Table(Literal("B"), (Field("j"),)), + ), + ), + "i", + ["A"], + ), + # Special case: max(C(i), D(j)), reduce over i → [max(C,D)] + ( + MapJoin( + Literal(max), + ( + Table(Literal("C"), (Field("i"),)), + Table(Literal("D"), (Field("j"),)), + ), + ), + "i", + # expected root is the entire max node + [ + MapJoin( + Literal(max), + ( + Table(Literal("C"), (Field("i"),)), + Table(Literal("D"), (Field("j"),)), + ), + ) + ], + ), + ], +) +def test_find_lowest_roots(root, idx_name, expected): + roots = find_lowest_roots(Literal(op.add), Field(idx_name), root) + + # Special-case: the max(C(i), D(j)) example – we expect the MapJoin itself. + if ( + isinstance(root, MapJoin) + and isinstance(root.op, Literal) + and root.op.val is max + and idx_name == "i" + ): + assert roots == expected + else: + # All other cases: + result: list[str] = [] + for node in roots: + assert isinstance(node, Table) + assert isinstance(node.tns, Literal) + result.append(node.tns.val) + + assert result == expected