88from collections .abc import Callable , Generator , Iterable
99from contextvars import ContextVar , Token
1010from dataclasses import InitVar , dataclass , field
11- from functools import partial
11+ from functools import cached_property , partial , total_ordering
1212from inspect import Parameter , currentframe , signature
1313from types import FrameType , FunctionType
1414from typing import (
5757 "DefaultCostModel" ,
5858 "EGraph" ,
5959 "Expr" ,
60+ "ExprCallable" ,
6061 "Fact" ,
6162 "Fact" ,
6263 "GraphvizKwargs" ,
64+ "GreedyDagCost" ,
65+ "GreedyDagCostModel" ,
6366 "RewriteOrRule" ,
6467 "Ruleset" ,
6568 "Schedule" ,
@@ -959,7 +962,7 @@ def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
959962
960963 @overload
961964 def extract (
962- self , expr : BASE_EXPR , / , include_cost : Literal [False ] = False , cost_model : CostModel [ Cost ] | None = None
965+ self , expr : BASE_EXPR , / , include_cost : Literal [False ] = False , cost_model : CostModel | None = None
963966 ) -> BASE_EXPR : ...
964967
965968 @overload
@@ -1978,75 +1981,125 @@ def get_cost(expr: BaseExpr) -> i64:
19781981 )
19791982
19801983
1981- class Cost (Protocol ):
1984+ class Comparable (Protocol ):
19821985 def __lt__ (self , other : Self ) -> bool : ...
19831986 def __le__ (self , other : Self ) -> bool : ...
19841987 def __gt__ (self , other : Self ) -> bool : ...
19851988 def __ge__ (self , other : Self ) -> bool : ...
19861989
19871990
1988- COST = TypeVar ("COST" , bound = Cost )
1991+ COST = TypeVar ("COST" , bound = Comparable )
19891992
19901993
19911994class CostModel (Protocol [COST ]):
19921995 """
19931996 A cost model for an e-graph. Used to determine the cost of an expression based on its structure and the costs of its sub-expressions.
19941997
19951998 Subclass this and implement the methods to create a custom cost model.
1999+
2000+ Additionally, the cost model should guarantee that a term has a no-smaller cost
2001+ than its subterms to avoid cycles in the extracted terms for common case usages.
2002+ For more niche usages, a term can have a cost less than its subterms.
2003+ As long as there is no negative cost cycle, the default extractor is guaranteed to terminate in computing the costs.
2004+ However, the user needs to be careful to guarantee acyclicity in the extracted terms.
19962005 """
19972006
19982007 @abstractmethod
19992008 def fold (self , callable : ExprCallable , children_costs : list [COST ], head_cost : COST ) -> COST :
20002009 """
20012010 The total cost of a term given the cost of the root e-node and its immediate children's total costs.
20022011 """
2012+ raise NotImplementedError
20032013
20042014 @abstractmethod
20052015 def call_cost (self , egraph : EGraph , expr : Expr ) -> COST :
20062016 """
20072017 The cost of an function call (without the cost of children).
20082018 """
2019+ raise NotImplementedError
20092020
20102021 @abstractmethod
20112022 def container_cost (self , egraph : EGraph , expr : Container , element_costs : list [COST ]) -> COST :
20122023 """
20132024 The cost of a container value given the costs of its elements.
20142025 """
2026+ raise NotImplementedError
20152027
20162028 @abstractmethod
20172029 def primitive_cost (self , egraph : EGraph , expr : Primitive ) -> COST :
20182030 """
20192031 The cost of a base value (like a literal or variable).
20202032 """
2033+ raise NotImplementedError
20212034
20222035
2023- class DefaultCostModel (CostModel [int ]):
2024- """
2025- A default cost model for an e-graph.
2036+ class ComparableAdd (Comparable , Protocol ):
2037+ def __add__ (self , other : Self ) -> Self : ...
20262038
2027- Subclass this to extend the default integer cost model.
2039+
2040+ BASE_COST = TypeVar ("BASE_COST" , bound = ComparableAdd )
2041+
2042+
2043+ class BaseCostModel (CostModel [BASE_COST ]):
20282044 """
2045+ Base cost model which provides default implementations for some methods, if the cost can be added and a 0 and 1 exist.
2046+ """
2047+
2048+ @property
2049+ @abstractmethod
2050+ def identity (self ) -> BASE_COST :
2051+ """
2052+ Identity element, such that COST + identity = COST.
2053+
2054+ Usually zero.
2055+ """
2056+ raise NotImplementedError
2057+
2058+ @property
2059+ @abstractmethod
2060+ def unit (self ) -> BASE_COST :
2061+ """
2062+ Unit element, default cost for node with no children, such that COST + unit > COST
2063+ """
2064+ raise NotImplementedError
20292065
2030- def fold (self , callable : ExprCallable , children_costs : list [int ], head_cost : int ) -> int :
2066+ def fold (self , callable : ExprCallable , children_costs : list [BASE_COST ], head_cost : BASE_COST ) -> BASE_COST :
20312067 """
20322068 The total cost of a term given the cost of the root e-node and its immediate children's total costs.
20332069 """
20342070 return sum (children_costs , start = head_cost )
20352071
2036- def container_cost (self , egraph : EGraph , expr : Container , element_costs : list [ int ] ) -> int :
2072+ def call_cost (self , egraph : EGraph , expr : Expr ) -> BASE_COST :
20372073 """
2038- The cost of a container value given the costs of its elements.
2074+ The cost of an function call (without the cost of children).
2075+ """
2076+ return self .unit
20392077
2040- The default cost for containers is just the sum of all the elements inside
2078+ def container_cost (self , egraph : EGraph , expr : Container , element_costs : list [BASE_COST ]) -> BASE_COST :
2079+ """
2080+ The cost of a container value given the costs of its elements.
20412081 """
2042- return sum (element_costs )
2082+ return sum (element_costs , start = self . identity )
20432083
2044- def primitive_cost (self , egraph : EGraph , expr : Primitive ) -> int :
2084+ def primitive_cost (self , egraph : EGraph , expr : Primitive ) -> BASE_COST :
20452085 """
20462086 The cost of a base value (like a literal or variable).
20472087 """
2048- return 1
2088+ return self .unit
2089+
20492090
2091+ class DefaultCostModel (BaseCostModel [int ]):
2092+ """
2093+ A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost.
2094+
2095+ Subclass this to extend the default integer cost model.
2096+ """
2097+
2098+ # TODO: Make cost model take identity and unit as args
2099+ identity = 0
2100+ unit = 1
2101+
2102+ # TODO: rename expr cost?
20502103 def call_cost (self , egraph : EGraph , expr : Expr ) -> int :
20512104 """
20522105 The cost of an enode is either the cost set on it, or the cost of the callable, or 1 if neither are set.
@@ -2065,6 +2118,108 @@ def call_cost(self, egraph: EGraph, expr: Expr) -> int:
20652118 return get_callable_cost (callable_fn ) or 1
20662119
20672120
2121+ class ComparableAddSub (ComparableAdd , Protocol ):
2122+ def __sub__ (self , other : Self ) -> Self : ...
2123+
2124+
2125+ DAG_COST = TypeVar ("DAG_COST" , bound = ComparableAddSub )
2126+
2127+
2128+ @total_ordering
2129+ @dataclass
2130+ class GreedyDagCost (Generic [DAG_COST ]):
2131+ expr : BaseExpr
2132+ costs : dict [BaseExpr , DAG_COST ]
2133+ identity : DAG_COST
2134+
2135+ def __eq__ (self , other : object ) -> bool :
2136+ if not isinstance (other , GreedyDagCost ):
2137+ return NotImplemented
2138+ return self .total == other .total
2139+
2140+ def __lt__ (self , other : GreedyDagCost ) -> bool :
2141+ return self .total < other .total
2142+
2143+ @cached_property
2144+ def total (self ) -> DAG_COST :
2145+ return sum (self .costs .values (), start = self .identity )
2146+
2147+ @classmethod
2148+ def from_children (
2149+ cls ,
2150+ expr : BaseExpr ,
2151+ children : list [GreedyDagCost [DAG_COST ]],
2152+ self_and_children : DAG_COST ,
2153+ identity : DAG_COST ,
2154+ ) -> GreedyDagCost [DAG_COST ]:
2155+ """
2156+ Create a GreedyDagCost from the costs of its children and the cost of itself and its children.
2157+
2158+ Make sure to subtract the costs of the children from self_and_children to get the cost of the node itself.
2159+ """
2160+ costs : dict [BaseExpr , DAG_COST ] = {}
2161+ for c in children :
2162+ for k , v in c .costs .items ():
2163+ if k in costs :
2164+ assert costs [k ] == v , f"Conflicting costs for { k } : { costs [k ]} and { v } "
2165+ else :
2166+ costs [k ] = v
2167+ for c in children :
2168+ self_and_children -= c .total
2169+ costs [expr ] = self_and_children
2170+ return cls (expr , costs , identity )
2171+
2172+ def __str__ (self ) -> str :
2173+ return f"GreedyDagCost(total={ self .total } )"
2174+
2175+ def __repr__ (self ) -> str :
2176+ return str (self )
2177+
2178+
2179+ @dataclass
2180+ class GreedyDagCostModel (CostModel [GreedyDagCost [DAG_COST ]]):
2181+ """
2182+ A cost model which will count duplicate nodes only once.
2183+
2184+ Should have similar behavior as https://github.com/egraphs-good/extraction-gym/blob/main/src/extract/greedy_dag.rs
2185+ but implemented as a cost model that will be used with the default extractor.
2186+ """
2187+
2188+ base : BaseCostModel [DAG_COST ]
2189+
2190+ def fold (
2191+ self ,
2192+ callable : ExprCallable ,
2193+ children_costs : list [GreedyDagCost [DAG_COST ]],
2194+ head_cost : GreedyDagCost [DAG_COST ],
2195+ ) -> GreedyDagCost [DAG_COST ]:
2196+ # head cost.total is the same as head_cost.costs[head_cost.expr] because it come from call_cost which always has one cost
2197+ base_fold = self .base .fold (callable , [c .total for c in children_costs ], head_cost .total )
2198+ return GreedyDagCost [DAG_COST ].from_children (head_cost .expr , children_costs , base_fold , self .base .identity )
2199+
2200+ def call_cost (self , egraph : EGraph , expr : Expr ) -> GreedyDagCost [DAG_COST ]:
2201+ """
2202+ The cost of an function call (without the cost of children).
2203+ """
2204+ return GreedyDagCost (expr , {expr : self .base .call_cost (egraph , expr )}, self .base .identity )
2205+
2206+ def container_cost (
2207+ self , egraph : EGraph , expr : Container , element_costs : list [GreedyDagCost [DAG_COST ]]
2208+ ) -> GreedyDagCost [DAG_COST ]:
2209+ """
2210+ The cost of a container value given the costs of its elements.
2211+ """
2212+ base_container_cost = self .base .container_cost (egraph , expr , [c .total for c in element_costs ])
2213+ return GreedyDagCost [DAG_COST ].from_children (expr , element_costs , base_container_cost , self .base .identity )
2214+
2215+ def primitive_cost (self , egraph : EGraph , expr : Primitive ) -> GreedyDagCost [DAG_COST ]:
2216+ """
2217+ The cost of a base value (like a literal or variable).
2218+ """
2219+ cost = self .base .primitive_cost (egraph , expr )
2220+ return GreedyDagCost (expr , {expr : cost }, self .base .identity )
2221+
2222+
20682223def get_callable_cost (fn : ExprCallable ) -> int | None :
20692224 """
20702225 Returns the cost of a callable, if it has one set. Otherwise returns None.
@@ -2119,12 +2274,4 @@ def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COS
21192274 return self .model .container_cost (self .egraph , cast ("Container" , expr ), element_costs )
21202275
21212276 def to_bindings_cost_model (self ) -> bindings .CostModel :
2122- model_tp = type (self .model )
2123- # Use custom costs if we have overriden them, otherwise use None to use the default in Rust for faster performance
2124- fold = self .fold if model_tp .fold is not DefaultCostModel .fold else None
2125- enode_cost = self .enode_cost if model_tp .call_cost is not DefaultCostModel .call_cost else None
2126- container_cost = self .container_cost if model_tp .container_cost is not DefaultCostModel .container_cost else None
2127- base_value_cost = (
2128- self .base_value_cost if model_tp .primitive_cost is not DefaultCostModel .primitive_cost else None
2129- )
2130- return bindings .CostModel (fold , enode_cost , container_cost , base_value_cost )
2277+ return bindings .CostModel (self .fold , self .enode_cost , self .container_cost , self .base_value_cost )
0 commit comments