Skip to content

Commit bdfee08

Browse files
Merge pull request #179 from egraphs-good/faster-let
Perform CSE when converting to egglog expression
2 parents cb485a2 + 6348cb7 commit bdfee08

File tree

1 file changed

+56
-7
lines changed

1 file changed

+56
-7
lines changed

python/egglog/egraph_state.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from collections import defaultdict
88
from dataclasses import dataclass, field
99
from typing import TYPE_CHECKING, overload
10-
from weakref import WeakKeyDictionary
1110

1211
from typing_extensions import assert_never
1312

@@ -52,7 +51,7 @@ class EGraphState:
5251
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
5352

5453
# Cache of egg expressions for converting to egg
55-
expr_to_egg_cache: WeakKeyDictionary[ExprDecl, bindings._Expr] = field(default_factory=WeakKeyDictionary)
54+
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
5655

5756
def copy(self) -> EGraphState:
5857
"""
@@ -147,6 +146,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
147146
def action_to_egg(self, action: ActionDecl) -> bindings._Action:
148147
match action:
149148
case LetDecl(name, typed_expr):
149+
self.expr_to_egg_cache[VarDecl(name)] = bindings.Var(name)
150150
return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
151151
case SetDecl(tp, call, rhs):
152152
self.type_ref_to_egg(tp)
@@ -180,7 +180,7 @@ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
180180
self.type_ref_to_egg(tp)
181181
return bindings.Eq([self._expr_to_egg(e) for e in exprs])
182182
case ExprFactDecl(typed_expr):
183-
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
183+
return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
184184
case _:
185185
assert_never(fact)
186186

@@ -272,10 +272,34 @@ def op_mapping(self) -> dict[str, str]:
272272
if len(v) == 1
273273
}
274274

275-
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
275+
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
276+
# transform all expressions with multiple parents into a let binding, so that less expressions
277+
# are sent to egglog. Only for performance reasons.
278+
if transform_let:
279+
have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
280+
for expr in reversed(have_multiple_parents):
281+
self._transform_let(expr)
282+
276283
self.type_ref_to_egg(typed_expr_decl.tp)
277284
return self._expr_to_egg(typed_expr_decl.expr)
278285

286+
def _transform_let(self, typed_expr: TypedExprDecl) -> None:
287+
"""
288+
Rewrites this expression as a let binding if it's not already a let binding.
289+
"""
290+
name = f"__expr_{hash(typed_expr)}"
291+
var_decl = VarDecl(name)
292+
if var_decl in self.expr_to_egg_cache:
293+
return
294+
cmd = bindings.ActionCommand(bindings.Let(name, self.typed_expr_to_egg(typed_expr)))
295+
try:
296+
self.egraph.run_program(cmd)
297+
# errors when creating let bindings for things like `(vec-empty)`
298+
except bindings.EggSmolError:
299+
return
300+
self.expr_to_egg_cache[typed_expr.expr] = bindings.Var(name)
301+
self.expr_to_egg_cache[var_decl] = bindings.Var(name)
302+
279303
@overload
280304
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
281305

@@ -287,12 +311,17 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
287311
Convert an ExprDecl to an egg expression.
288312
289313
Cached using weakrefs to avoid memory leaks.
314+
315+
If transform_let is True, then we will create a let binding for the expression if its children dont contain any unboard variables.
316+
317+
If it's false, it won't.
318+
319+
If it's "not-first", then it will skip trying to create a let binding on the top level, but then will for the rest
290320
"""
291321
try:
292322
return self.expr_to_egg_cache[expr_decl]
293323
except KeyError:
294324
pass
295-
296325
res: bindings._Expr
297326
match expr_decl:
298327
case VarDecl(name):
@@ -315,7 +344,7 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
315344
res = bindings.Lit(l)
316345
case CallDecl(ref, args, _):
317346
egg_fn = self.callable_ref_to_egg(ref)
318-
egg_args = [self.typed_expr_to_egg(a) for a in args]
347+
egg_args = [self.typed_expr_to_egg(a, False) for a in args]
319348
res = bindings.Call(egg_fn, egg_args)
320349
case PyObjectDecl(value):
321350
res = GLOBAL_PY_OBJECT_SORT.store(value)
@@ -324,7 +353,6 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
324353
res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
325354
case _:
326355
assert_never(expr_decl.expr)
327-
328356
self.expr_to_egg_cache[expr_decl] = res
329357
return res
330358

@@ -344,6 +372,27 @@ def _get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
344372
return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
345373

346374

375+
def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
376+
"""
377+
Returns all expressions that have multiple parents (a list but semantically just an ordered set).
378+
"""
379+
to_traverse = {typed_expr}
380+
traversed = set[TypedExprDecl]()
381+
traversed_twice = list[TypedExprDecl]()
382+
while to_traverse:
383+
typed_expr = to_traverse.pop()
384+
if typed_expr in traversed:
385+
traversed_twice.append(typed_expr)
386+
continue
387+
traversed.add(typed_expr)
388+
expr = typed_expr.expr
389+
if isinstance(expr, CallDecl):
390+
to_traverse.update(expr.args)
391+
elif isinstance(expr, PartialCallDecl):
392+
to_traverse.update(expr.call.args)
393+
return traversed_twice
394+
395+
347396
def _generate_type_egg_name(ref: JustTypeRef) -> str:
348397
"""
349398
Generates an egg sort name for this type reference by linearizing the type.

0 commit comments

Comments
 (0)