77from collections import defaultdict
88from dataclasses import dataclass , field
99from typing import TYPE_CHECKING , overload
10- from weakref import WeakKeyDictionary
1110
1211from typing_extensions import assert_never
1312
@@ -52,7 +51,7 @@ class EGraphState:
5251 type_ref_to_egg_sort : dict [JustTypeRef , str ] = field (default_factory = dict )
5352
5453 # Cache of egg expressions for converting to egg
55- expr_to_egg_cache : WeakKeyDictionary [ExprDecl , bindings ._Expr ] = field (default_factory = WeakKeyDictionary )
54+ expr_to_egg_cache : dict [ExprDecl , bindings ._Expr ] = field (default_factory = dict )
5655
5756 def copy (self ) -> EGraphState :
5857 """
@@ -147,6 +146,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
147146 def action_to_egg (self , action : ActionDecl ) -> bindings ._Action :
148147 match action :
149148 case LetDecl (name , typed_expr ):
149+ self .expr_to_egg_cache [VarDecl (name )] = bindings .Var (name )
150150 return bindings .Let (name , self .typed_expr_to_egg (typed_expr ))
151151 case SetDecl (tp , call , rhs ):
152152 self .type_ref_to_egg (tp )
@@ -180,7 +180,7 @@ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
180180 self .type_ref_to_egg (tp )
181181 return bindings .Eq ([self ._expr_to_egg (e ) for e in exprs ])
182182 case ExprFactDecl (typed_expr ):
183- return bindings .Fact (self .typed_expr_to_egg (typed_expr ))
183+ return bindings .Fact (self .typed_expr_to_egg (typed_expr , False ))
184184 case _:
185185 assert_never (fact )
186186
@@ -272,10 +272,34 @@ def op_mapping(self) -> dict[str, str]:
272272 if len (v ) == 1
273273 }
274274
275- def typed_expr_to_egg (self , typed_expr_decl : TypedExprDecl ) -> bindings ._Expr :
275+ def typed_expr_to_egg (self , typed_expr_decl : TypedExprDecl , transform_let : bool = True ) -> bindings ._Expr :
276+ # transform all expressions with multiple parents into a let binding, so that less expressions
277+ # are sent to egglog. Only for performance reasons.
278+ if transform_let :
279+ have_multiple_parents = _exprs_multiple_parents (typed_expr_decl )
280+ for expr in reversed (have_multiple_parents ):
281+ self ._transform_let (expr )
282+
276283 self .type_ref_to_egg (typed_expr_decl .tp )
277284 return self ._expr_to_egg (typed_expr_decl .expr )
278285
286+ def _transform_let (self , typed_expr : TypedExprDecl ) -> None :
287+ """
288+ Rewrites this expression as a let binding if it's not already a let binding.
289+ """
290+ name = f"__expr_{ hash (typed_expr )} "
291+ var_decl = VarDecl (name )
292+ if var_decl in self .expr_to_egg_cache :
293+ return
294+ cmd = bindings .ActionCommand (bindings .Let (name , self .typed_expr_to_egg (typed_expr )))
295+ try :
296+ self .egraph .run_program (cmd )
297+ # errors when creating let bindings for things like `(vec-empty)`
298+ except bindings .EggSmolError :
299+ return
300+ self .expr_to_egg_cache [typed_expr .expr ] = bindings .Var (name )
301+ self .expr_to_egg_cache [var_decl ] = bindings .Var (name )
302+
279303 @overload
280304 def _expr_to_egg (self , expr_decl : CallDecl ) -> bindings .Call : ...
281305
@@ -287,12 +311,17 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
287311 Convert an ExprDecl to an egg expression.
288312
289313 Cached using weakrefs to avoid memory leaks.
314+
315+ If transform_let is True, then we will create a let binding for the expression if its children dont contain any unboard variables.
316+
317+ If it's false, it won't.
318+
319+ If it's "not-first", then it will skip trying to create a let binding on the top level, but then will for the rest
290320 """
291321 try :
292322 return self .expr_to_egg_cache [expr_decl ]
293323 except KeyError :
294324 pass
295-
296325 res : bindings ._Expr
297326 match expr_decl :
298327 case VarDecl (name ):
@@ -315,7 +344,7 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
315344 res = bindings .Lit (l )
316345 case CallDecl (ref , args , _):
317346 egg_fn = self .callable_ref_to_egg (ref )
318- egg_args = [self .typed_expr_to_egg (a ) for a in args ]
347+ egg_args = [self .typed_expr_to_egg (a , False ) for a in args ]
319348 res = bindings .Call (egg_fn , egg_args )
320349 case PyObjectDecl (value ):
321350 res = GLOBAL_PY_OBJECT_SORT .store (value )
@@ -324,7 +353,6 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
324353 res = bindings .Call ("unstable-fn" , [bindings .Lit (bindings .String (egg_fn_call .name )), * egg_fn_call .args ])
325354 case _:
326355 assert_never (expr_decl .expr )
327-
328356 self .expr_to_egg_cache [expr_decl ] = res
329357 return res
330358
@@ -344,6 +372,27 @@ def _get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
344372 return frozenset (tp for tp in self .type_ref_to_egg_sort if tp .name == cls_name )
345373
346374
375+ def _exprs_multiple_parents (typed_expr : TypedExprDecl ) -> list [TypedExprDecl ]:
376+ """
377+ Returns all expressions that have multiple parents (a list but semantically just an ordered set).
378+ """
379+ to_traverse = {typed_expr }
380+ traversed = set [TypedExprDecl ]()
381+ traversed_twice = list [TypedExprDecl ]()
382+ while to_traverse :
383+ typed_expr = to_traverse .pop ()
384+ if typed_expr in traversed :
385+ traversed_twice .append (typed_expr )
386+ continue
387+ traversed .add (typed_expr )
388+ expr = typed_expr .expr
389+ if isinstance (expr , CallDecl ):
390+ to_traverse .update (expr .args )
391+ elif isinstance (expr , PartialCallDecl ):
392+ to_traverse .update (expr .call .args )
393+ return traversed_twice
394+
395+
347396def _generate_type_egg_name (ref : JustTypeRef ) -> str :
348397 """
349398 Generates an egg sort name for this type reference by linearizing the type.
0 commit comments