Skip to content

Commit 468172b

Browse files
Cache cost models
1 parent fb0ffd0 commit 468172b

File tree

1 file changed

+47
-14
lines changed

1 file changed

+47
-14
lines changed

python/egglog/egraph.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,9 @@ def __gt__(self, other: Self) -> bool:
20682068
def __ge__(self, other: Self) -> bool:
20692069
return self.total >= other.total
20702070

2071+
def __hash__(self) -> int:
2072+
return hash(self.total)
2073+
20712074

20722075
@dataclass
20732076
class GreedyDagCostModel(CostModel[GreedyDagCost[DAG_COST]]):
@@ -2126,21 +2129,35 @@ class _CostModel(Generic[COST]):
21262129

21272130
model: CostModel[COST]
21282131
egraph: EGraph
2132+
enode_cost_results: dict[tuple[str, tuple[bindings.Value, ...]], int] = field(default_factory=dict)
2133+
enode_cost_expressions: list[RuntimeExpr] = field(default_factory=list)
2134+
fold_results: dict[tuple[int, tuple[COST, ...]], COST] = field(default_factory=dict)
2135+
base_value_cost_results: dict[tuple[str, bindings.Value], COST] = field(default_factory=dict)
2136+
container_cost_results: dict[tuple[str, bindings.Value, tuple[COST, ...]], COST] = field(default_factory=dict)
21292137

21302138
def call_model(self, expr: RuntimeExpr, children_costs: list[COST]) -> COST:
2131-
res = self.model(self.egraph, cast("BaseExpr", expr), children_costs)
2132-
if __debug__:
2133-
for c in children_costs:
2134-
if res <= c:
2135-
msg = f"Cost model {self.model} produced a cost {res} less than or equal to a child cost {c} for {expr}"
2136-
raise ValueError(msg)
2137-
return res
2139+
return self.model(self.egraph, cast("BaseExpr", expr), children_costs)
2140+
# if __debug__:
2141+
# for c in children_costs:
2142+
# if res <= c:
2143+
# msg = f"Cost model {self.model} produced a cost {res} less than or equal to a child cost {c} for {expr}"
2144+
# raise ValueError(msg)
2145+
2146+
def fold(self, _fn: str, index: int, children_costs: list[COST]) -> COST:
2147+
try:
2148+
return self.fold_results[(index, tuple(children_costs))]
2149+
except KeyError:
2150+
pass
21382151

2139-
def fold(self, _fn: str, head_cost: RuntimeExpr, children_costs: list[COST]) -> COST:
2140-
return self.call_model(head_cost, children_costs)
2152+
expr = self.enode_cost_expressions[index]
2153+
return self.call_model(expr, children_costs)
21412154

21422155
# enode cost is only ever called right before fold, for the head_cost
2143-
def enode_cost(self, name: str, args: list[bindings.Value]) -> RuntimeExpr:
2156+
def enode_cost(self, name: str, args: list[bindings.Value]) -> int:
2157+
try:
2158+
return self.enode_cost_results[(name, tuple(args))]
2159+
except KeyError:
2160+
pass
21442161
(callable_ref,) = self.egraph._state.egg_fn_to_callable_refs[name]
21452162
signature = self.egraph.__egg_decls__.get_callable_decl(callable_ref).signature
21462163
assert isinstance(signature, FunctionSignature)
@@ -2149,26 +2166,42 @@ def enode_cost(self, name: str, args: list[bindings.Value]) -> RuntimeExpr:
21492166
for (arg, tp) in zip(args, signature.arg_types, strict=True)
21502167
]
21512168
res_type = signature.semantic_return_type.to_just()
2152-
return RuntimeExpr.__from_values__(
2169+
res = RuntimeExpr.__from_values__(
21532170
self.egraph.__egg_decls__,
21542171
TypedExprDecl(res_type, CallDecl(callable_ref, tuple(arg_exprs))),
21552172
)
2173+
index = len(self.enode_cost_expressions)
2174+
self.enode_cost_expressions.append(res)
2175+
self.enode_cost_results[(name, tuple(args))] = index
2176+
return index
21562177

21572178
def base_value_cost(self, tp: str, value: bindings.Value) -> COST:
2179+
try:
2180+
return self.base_value_cost_results[(tp, value)]
2181+
except KeyError:
2182+
pass
21582183
type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
21592184
expr = RuntimeExpr.__from_values__(
21602185
self.egraph.__egg_decls__,
21612186
TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
21622187
)
2163-
return self.call_model(expr, [])
2188+
res = self.call_model(expr, [])
2189+
self.base_value_cost_results[(tp, value)] = res
2190+
return res
21642191

21652192
def container_cost(self, tp: str, value: bindings.Value, element_costs: list[COST]) -> COST:
2193+
try:
2194+
return self.container_cost_results[(tp, value, tuple(element_costs))]
2195+
except KeyError:
2196+
pass
21662197
type_ref = self.egraph._state.egg_sort_to_type_ref[tp]
21672198
expr = RuntimeExpr.__from_values__(
21682199
self.egraph.__egg_decls__,
21692200
TypedExprDecl(type_ref, self.egraph._state.value_to_expr(type_ref, value)),
21702201
)
2171-
return self.call_model(expr, element_costs)
2202+
res = self.call_model(expr, element_costs)
2203+
self.container_cost_results[(tp, value, tuple(element_costs))] = res
2204+
return res
21722205

2173-
def to_bindings_cost_model(self) -> bindings.CostModel[COST, RuntimeExpr]:
2206+
def to_bindings_cost_model(self) -> bindings.CostModel[COST, int]:
21742207
return bindings.CostModel(self.fold, self.enode_cost, self.container_cost, self.base_value_cost)

0 commit comments

Comments
 (0)