Skip to content

Commit 042e76e

Browse files
Bubble up errors and always call into Python for costs
1 parent ff280f8 commit 042e76e

File tree

4 files changed

+273
-117
lines changed

4 files changed

+273
-117
lines changed

python/egglog/bindings.pyi

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -771,18 +771,17 @@ class _Cost(Protocol):
771771
def __le__(self, other: _Cost) -> bool: ...
772772
def __gt__(self, other: _Cost) -> bool: ...
773773
def __ge__(self, other: _Cost) -> bool: ...
774-
def __add__(self, other: _Cost) -> _Cost: ...
775774

776775
_COST = TypeVar("_COST", bound=_Cost)
777776

778777
@final
779778
class CostModel(Generic[_COST]):
780779
def __init__(
781780
self,
782-
fold: Callable[[str, _COST, list[_COST]], _COST] | None,
783-
enode_cost: Callable[[str, list[Value]], _COST] | None,
784-
container_cost: Callable[[str, Value, list[_COST]], _COST] | None,
785-
base_value_cost: Callable[[str, Value], _COST] | None,
781+
fold: Callable[[str, _COST, list[_COST]], _COST],
782+
enode_cost: Callable[[str, list[Value]], _COST],
783+
container_cost: Callable[[str, Value, list[_COST]], _COST],
784+
base_value_cost: Callable[[str, Value], _COST],
786785
) -> None: ...
787786

788787
@final

python/egglog/egraph.py

Lines changed: 171 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Callable, Generator, Iterable
99
from contextvars import ContextVar, Token
1010
from dataclasses import InitVar, dataclass, field
11-
from functools import partial
11+
from functools import cached_property, partial, total_ordering
1212
from inspect import Parameter, currentframe, signature
1313
from types import FrameType, FunctionType
1414
from typing import (
@@ -57,9 +57,12 @@
5757
"DefaultCostModel",
5858
"EGraph",
5959
"Expr",
60+
"ExprCallable",
6061
"Fact",
6162
"Fact",
6263
"GraphvizKwargs",
64+
"GreedyDagCost",
65+
"GreedyDagCostModel",
6366
"RewriteOrRule",
6467
"Ruleset",
6568
"Schedule",
@@ -959,7 +962,7 @@ def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
959962

960963
@overload
961964
def extract(
962-
self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel[Cost] | None = None
965+
self, expr: BASE_EXPR, /, include_cost: Literal[False] = False, cost_model: CostModel | None = None
963966
) -> BASE_EXPR: ...
964967

965968
@overload
@@ -1978,75 +1981,125 @@ def get_cost(expr: BaseExpr) -> i64:
19781981
)
19791982

19801983

1981-
class Cost(Protocol):
1984+
class Comparable(Protocol):
19821985
def __lt__(self, other: Self) -> bool: ...
19831986
def __le__(self, other: Self) -> bool: ...
19841987
def __gt__(self, other: Self) -> bool: ...
19851988
def __ge__(self, other: Self) -> bool: ...
19861989

19871990

1988-
COST = TypeVar("COST", bound=Cost)
1991+
COST = TypeVar("COST", bound=Comparable)
19891992

19901993

19911994
class CostModel(Protocol[COST]):
19921995
"""
19931996
A cost model for an e-graph. Used to determine the cost of an expression based on its structure and the costs of its sub-expressions.
19941997
19951998
Subclass this and implement the methods to create a custom cost model.
1999+
2000+
Additionally, the cost model should guarantee that a term has a no-smaller cost
2001+
than its subterms to avoid cycles in the extracted terms for common case usages.
2002+
For more niche usages, a term can have a cost less than its subterms.
2003+
As long as there is no negative cost cycle, the default extractor is guaranteed to terminate in computing the costs.
2004+
However, the user needs to be careful to guarantee acyclicity in the extracted terms.
19962005
"""
19972006

19982007
@abstractmethod
19992008
def fold(self, callable: ExprCallable, children_costs: list[COST], head_cost: COST) -> COST:
20002009
"""
20012010
The total cost of a term given the cost of the root e-node and its immediate children's total costs.
20022011
"""
2012+
raise NotImplementedError
20032013

20042014
@abstractmethod
20052015
def call_cost(self, egraph: EGraph, expr: Expr) -> COST:
20062016
"""
20072017
The cost of an function call (without the cost of children).
20082018
"""
2019+
raise NotImplementedError
20092020

20102021
@abstractmethod
20112022
def container_cost(self, egraph: EGraph, expr: Container, element_costs: list[COST]) -> COST:
20122023
"""
20132024
The cost of a container value given the costs of its elements.
20142025
"""
2026+
raise NotImplementedError
20152027

20162028
@abstractmethod
20172029
def primitive_cost(self, egraph: EGraph, expr: Primitive) -> COST:
20182030
"""
20192031
The cost of a base value (like a literal or variable).
20202032
"""
2033+
raise NotImplementedError
20212034

20222035

2023-
class DefaultCostModel(CostModel[int]):
2024-
"""
2025-
A default cost model for an e-graph.
2036+
class ComparableAdd(Comparable, Protocol):
2037+
def __add__(self, other: Self) -> Self: ...
20262038

2027-
Subclass this to extend the default integer cost model.
2039+
2040+
BASE_COST = TypeVar("BASE_COST", bound=ComparableAdd)
2041+
2042+
2043+
class BaseCostModel(CostModel[BASE_COST]):
20282044
"""
2045+
Base cost model which provides default implementations for some methods, if the cost can be added and a 0 and 1 exist.
2046+
"""
2047+
2048+
@property
2049+
@abstractmethod
2050+
def identity(self) -> BASE_COST:
2051+
"""
2052+
Identity element, such that COST + identity = COST.
2053+
2054+
Usually zero.
2055+
"""
2056+
raise NotImplementedError
2057+
2058+
@property
2059+
@abstractmethod
2060+
def unit(self) -> BASE_COST:
2061+
"""
2062+
Unit element, default cost for node with no children, such that COST + unit > COST
2063+
"""
2064+
raise NotImplementedError
20292065

2030-
def fold(self, callable: ExprCallable, children_costs: list[int], head_cost: int) -> int:
2066+
def fold(self, callable: ExprCallable, children_costs: list[BASE_COST], head_cost: BASE_COST) -> BASE_COST:
20312067
"""
20322068
The total cost of a term given the cost of the root e-node and its immediate children's total costs.
20332069
"""
20342070
return sum(children_costs, start=head_cost)
20352071

2036-
def container_cost(self, egraph: EGraph, expr: Container, element_costs: list[int]) -> int:
2072+
def call_cost(self, egraph: EGraph, expr: Expr) -> BASE_COST:
20372073
"""
2038-
The cost of a container value given the costs of its elements.
2074+
The cost of an function call (without the cost of children).
2075+
"""
2076+
return self.unit
20392077

2040-
The default cost for containers is just the sum of all the elements inside
2078+
def container_cost(self, egraph: EGraph, expr: Container, element_costs: list[BASE_COST]) -> BASE_COST:
2079+
"""
2080+
The cost of a container value given the costs of its elements.
20412081
"""
2042-
return sum(element_costs)
2082+
return sum(element_costs, start=self.identity)
20432083

2044-
def primitive_cost(self, egraph: EGraph, expr: Primitive) -> int:
2084+
def primitive_cost(self, egraph: EGraph, expr: Primitive) -> BASE_COST:
20452085
"""
20462086
The cost of a base value (like a literal or variable).
20472087
"""
2048-
return 1
2088+
return self.unit
2089+
20492090

2091+
class DefaultCostModel(BaseCostModel[int]):
2092+
"""
2093+
A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost.
2094+
2095+
Subclass this to extend the default integer cost model.
2096+
"""
2097+
2098+
# TODO: Make cost model take identity and unit as args
2099+
identity = 0
2100+
unit = 1
2101+
2102+
# TODO: rename expr cost?
20502103
def call_cost(self, egraph: EGraph, expr: Expr) -> int:
20512104
"""
20522105
The cost of an enode is either the cost set on it, or the cost of the callable, or 1 if neither are set.
@@ -2065,6 +2118,108 @@ def call_cost(self, egraph: EGraph, expr: Expr) -> int:
20652118
return get_callable_cost(callable_fn) or 1
20662119

20672120

2121+
class ComparableAddSub(ComparableAdd, Protocol):
2122+
def __sub__(self, other: Self) -> Self: ...
2123+
2124+
2125+
DAG_COST = TypeVar("DAG_COST", bound=ComparableAddSub)
2126+
2127+
2128+
@total_ordering
2129+
@dataclass
2130+
class GreedyDagCost(Generic[DAG_COST]):
2131+
expr: BaseExpr
2132+
costs: dict[BaseExpr, DAG_COST]
2133+
identity: DAG_COST
2134+
2135+
def __eq__(self, other: object) -> bool:
2136+
if not isinstance(other, GreedyDagCost):
2137+
return NotImplemented
2138+
return self.total == other.total
2139+
2140+
def __lt__(self, other: GreedyDagCost) -> bool:
2141+
return self.total < other.total
2142+
2143+
@cached_property
2144+
def total(self) -> DAG_COST:
2145+
return sum(self.costs.values(), start=self.identity)
2146+
2147+
@classmethod
2148+
def from_children(
2149+
cls,
2150+
expr: BaseExpr,
2151+
children: list[GreedyDagCost[DAG_COST]],
2152+
self_and_children: DAG_COST,
2153+
identity: DAG_COST,
2154+
) -> GreedyDagCost[DAG_COST]:
2155+
"""
2156+
Create a GreedyDagCost from the costs of its children and the cost of itself and its children.
2157+
2158+
Make sure to subtract the costs of the children from self_and_children to get the cost of the node itself.
2159+
"""
2160+
costs: dict[BaseExpr, DAG_COST] = {}
2161+
for c in children:
2162+
for k, v in c.costs.items():
2163+
if k in costs:
2164+
assert costs[k] == v, f"Conflicting costs for {k}: {costs[k]} and {v}"
2165+
else:
2166+
costs[k] = v
2167+
for c in children:
2168+
self_and_children -= c.total
2169+
costs[expr] = self_and_children
2170+
return cls(expr, costs, identity)
2171+
2172+
def __str__(self) -> str:
2173+
return f"GreedyDagCost(total={self.total})"
2174+
2175+
def __repr__(self) -> str:
2176+
return str(self)
2177+
2178+
2179+
@dataclass
2180+
class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]):
2181+
"""
2182+
A cost model which will count duplicate nodes only once.
2183+
2184+
Should have similar behavior as https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs
2185+
but implemented as a cost model that will be used with the default extractor.
2186+
"""
2187+
2188+
base: BaseCostModel[DAG_COST]
2189+
2190+
def fold(
2191+
self,
2192+
callable: ExprCallable,
2193+
children_costs: list[GreedyDagCost[DAG_COST]],
2194+
head_cost: GreedyDagCost[DAG_COST],
2195+
) -> GreedyDagCost[DAG_COST]:
2196+
# head cost.total is the same as head_cost.costs[head_cost.expr] because it come from call_cost which always has one cost
2197+
base_fold = self.base.fold(callable, [c.total for c in children_costs], head_cost.total)
2198+
return GreedyDagCost[DAG_COST].from_children(head_cost.expr, children_costs, base_fold, self.base.identity)
2199+
2200+
def call_cost(self, egraph: EGraph, expr: Expr) -> GreedyDagCost[DAG_COST]:
2201+
"""
2202+
The cost of an function call (without the cost of children).
2203+
"""
2204+
return GreedyDagCost(expr, {expr: self.base.call_cost(egraph, expr)}, self.base.identity)
2205+
2206+
def container_cost(
2207+
self, egraph: EGraph, expr: Container, element_costs: list[GreedyDagCost[DAG_COST]]
2208+
) -> GreedyDagCost[DAG_COST]:
2209+
"""
2210+
The cost of a container value given the costs of its elements.
2211+
"""
2212+
base_container_cost = self.base.container_cost(egraph, expr, [c.total for c in element_costs])
2213+
return GreedyDagCost[DAG_COST].from_children(expr, element_costs, base_container_cost, self.base.identity)
2214+
2215+
def primitive_cost(self, egraph: EGraph, expr: Primitive) -> GreedyDagCost[DAG_COST]:
2216+
"""
2217+
The cost of a base value (like a literal or variable).
2218+
"""
2219+
cost = self.base.primitive_cost(egraph, expr)
2220+
return GreedyDagCost(expr, {expr: cost}, self.base.identity)
2221+
2222+
20682223
def get_callable_cost(fn: ExprCallable) -> int | None:
20692224
"""
20702225
Returns the cost of a callable, if it has one set. Otherwise returns None.
@@ -2119,12 +2274,4 @@ def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COS
21192274
return self.model.container_cost(self.egraph, cast("Container", expr), element_costs)
21202275

21212276
def to_bindings_cost_model(self) -> bindings.CostModel:
2122-
model_tp = type(self.model)
2123-
# Use custom costs if we have overriden them, otherwise use None to use the default in Rust for faster performance
2124-
fold = self.fold if model_tp.fold is not DefaultCostModel.fold else None
2125-
enode_cost = self.enode_cost if model_tp.call_cost is not DefaultCostModel.call_cost else None
2126-
container_cost = self.container_cost if model_tp.container_cost is not DefaultCostModel.container_cost else None
2127-
base_value_cost = (
2128-
self.base_value_cost if model_tp.primitive_cost is not DefaultCostModel.primitive_cost else None
2129-
)
2130-
return bindings.CostModel(fold, enode_cost, container_cost, base_value_cost)
2277+
return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost)

python/tests/test_high_level.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,3 +1311,14 @@ def fold(self, callable, child_costs, head_cost):
13111311
fold_spy.assert_any_call(my_f, [5], 1)
13121312
enode_cost_spy.assert_any_call(egraph, E())
13131313
enode_cost_spy.assert_any_call(egraph, call)
1314+
1315+
def test_errors_bubble(self):
1316+
class MyCostModel(DefaultCostModel):
1317+
def primitive_cost(self, egraph, expr):
1318+
msg = "bad"
1319+
raise ValueError(msg)
1320+
1321+
egraph = EGraph()
1322+
1323+
with pytest.raises(ValueError, match="bad"):
1324+
egraph.extract(i64(10), cost_model=MyCostModel())

0 commit comments

Comments
 (0)