@@ -2068,6 +2068,9 @@ def __gt__(self, other: Self) -> bool:
20682068 def __ge__ (self , other : Self ) -> bool :
20692069 return self .total >= other .total
20702070
2071+ def __hash__ (self ) -> int :
2072+ return hash (self .total )
2073+
20712074
20722075@dataclass
20732076class GreedyDagCostModel (CostModel [GreedyDagCost [DAG_COST ]]):
@@ -2126,21 +2129,35 @@ class _CostModel(Generic[COST]):
21262129
21272130 model : CostModel [COST ]
21282131 egraph : EGraph
2132+ enode_cost_results : dict [tuple [str , tuple [bindings .Value , ...]], int ] = field (default_factory = dict )
2133+ enode_cost_expressions : list [RuntimeExpr ] = field (default_factory = list )
2134+ fold_results : dict [tuple [int , tuple [COST , ...]], COST ] = field (default_factory = dict )
2135+ base_value_cost_results : dict [tuple [str , bindings .Value ], COST ] = field (default_factory = dict )
2136+ container_cost_results : dict [tuple [str , bindings .Value , tuple [COST , ...]], COST ] = field (default_factory = dict )
21292137
21302138 def call_model (self , expr : RuntimeExpr , children_costs : list [COST ]) -> COST :
2131- res = self .model (self .egraph , cast ("BaseExpr" , expr ), children_costs )
2132- if __debug__ :
2133- for c in children_costs :
2134- if res <= c :
2135- msg = f"Cost model { self .model } produced a cost { res } less than or equal to a child cost { c } for { expr } "
2136- raise ValueError (msg )
2137- return res
2139+ return self .model (self .egraph , cast ("BaseExpr" , expr ), children_costs )
2140+ # if __debug__:
2141+ # for c in children_costs:
2142+ # if res <= c:
2143+ # msg = f"Cost model {self.model} produced a cost {res} less than or equal to a child cost {c} for {expr}"
2144+ # raise ValueError(msg)
2145+
2146+ def fold (self , _fn : str , index : int , children_costs : list [COST ]) -> COST :
2147+ try :
2148+ return self .fold_results [(index , tuple (children_costs ))]
2149+ except KeyError :
2150+ pass
21382151
2139- def fold ( self , _fn : str , head_cost : RuntimeExpr , children_costs : list [ COST ]) -> COST :
2140- return self .call_model (head_cost , children_costs )
2152+ expr = self . enode_cost_expressions [ index ]
2153+ return self .call_model (expr , children_costs )
21412154
21422155 # enode cost is only ever called right before fold, for the head_cost
2143- def enode_cost (self , name : str , args : list [bindings .Value ]) -> RuntimeExpr :
2156+ def enode_cost (self , name : str , args : list [bindings .Value ]) -> int :
2157+ try :
2158+ return self .enode_cost_results [(name , tuple (args ))]
2159+ except KeyError :
2160+ pass
21442161 (callable_ref ,) = self .egraph ._state .egg_fn_to_callable_refs [name ]
21452162 signature = self .egraph .__egg_decls__ .get_callable_decl (callable_ref ).signature
21462163 assert isinstance (signature , FunctionSignature )
@@ -2149,26 +2166,42 @@ def enode_cost(self, name: str, args: list[bindings.Value]) -> RuntimeExpr:
21492166 for (arg , tp ) in zip (args , signature .arg_types , strict = True )
21502167 ]
21512168 res_type = signature .semantic_return_type .to_just ()
2152- return RuntimeExpr .__from_values__ (
2169+ res = RuntimeExpr .__from_values__ (
21532170 self .egraph .__egg_decls__ ,
21542171 TypedExprDecl (res_type , CallDecl (callable_ref , tuple (arg_exprs ))),
21552172 )
2173+ index = len (self .enode_cost_expressions )
2174+ self .enode_cost_expressions .append (res )
2175+ self .enode_cost_results [(name , tuple (args ))] = index
2176+ return index
21562177
21572178 def base_value_cost (self , tp : str , value : bindings .Value ) -> COST :
2179+ try :
2180+ return self .base_value_cost_results [(tp , value )]
2181+ except KeyError :
2182+ pass
21582183 type_ref = self .egraph ._state .egg_sort_to_type_ref [tp ]
21592184 expr = RuntimeExpr .__from_values__ (
21602185 self .egraph .__egg_decls__ ,
21612186 TypedExprDecl (type_ref , self .egraph ._state .value_to_expr (type_ref , value )),
21622187 )
2163- return self .call_model (expr , [])
2188+ res = self .call_model (expr , [])
2189+ self .base_value_cost_results [(tp , value )] = res
2190+ return res
21642191
21652192 def container_cost (self , tp : str , value : bindings .Value , element_costs : list [COST ]) -> COST :
2193+ try :
2194+ return self .container_cost_results [(tp , value , tuple (element_costs ))]
2195+ except KeyError :
2196+ pass
21662197 type_ref = self .egraph ._state .egg_sort_to_type_ref [tp ]
21672198 expr = RuntimeExpr .__from_values__ (
21682199 self .egraph .__egg_decls__ ,
21692200 TypedExprDecl (type_ref , self .egraph ._state .value_to_expr (type_ref , value )),
21702201 )
2171- return self .call_model (expr , element_costs )
2202+ res = self .call_model (expr , element_costs )
2203+ self .container_cost_results [(tp , value , tuple (element_costs ))] = res
2204+ return res
21722205
2173- def to_bindings_cost_model (self ) -> bindings .CostModel [COST , RuntimeExpr ]:
2206+ def to_bindings_cost_model (self ) -> bindings .CostModel [COST , int ]:
21742207 return bindings .CostModel (self .fold , self .enode_cost , self .container_cost , self .base_value_cost )
0 commit comments