Skip to content

Commit 7f7a52b

Browse files
Remove automatic extraction in builtin evaluation
In #265 (released as 9.0.0) the ability to automatically extract a builtin using a global egraph context was added. This PR removes that feature, requiring all builtins to be in a normalized form. I realized that for #241 we want facts to compare structural equality when converting to a boolean, instead of using the e-graph. Looking at that previous PR, it seems like a mistake to add this implicit context, making things more confusing and opaque with minimal UX improvements.
1 parent 640e8a1 commit 7f7a52b

File tree

14 files changed

+194
-178
lines changed

14 files changed

+194
-178
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Change builtins to not evaluate values in egraph and changes facts to compare structural equality instead of using an egraph when converting to a boolean, removing magic context (`EGraph.current` and `Schedule.current`) that was added in release 9.0.0.
8+
79
## 9.0.1 (2025-03-20)
810

911
- Add missing i64.log2 method to the bindings

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _num_rule(a: Num, b: Num, c: Num, i: i64, j: i64):
4545
yield rewrite(Num(i) * Num(j)).to(Num(i * j))
4646
4747
egraph.saturate()
48-
egraph.check(eq(expr1).to(expr2))
48+
egraph.check(expr1 == expr2)
4949
egraph.extract(expr1)
5050
```
5151

python/egglog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from . import config, ipython_magic # noqa: F401
6+
from .bindings import EggSmolError # noqa: F401
67
from .builtins import * # noqa: UP029
78
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
89
from .egraph import *

python/egglog/builtins.py

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212

1313
from typing_extensions import TypeVarTuple, Unpack
1414

15-
from . import bindings
1615
from .conversion import convert, converter, get_type_args
1716
from .declarations import *
18-
from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
19-
from .egraph_state import GLOBAL_PY_OBJECT_SORT
17+
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
2018
from .functionalize import functionalize
2119
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
2220
from .thunk import Thunk
@@ -32,6 +30,7 @@
3230
"BigRatLike",
3331
"Bool",
3432
"BoolLike",
33+
"BuiltinEvalError",
3534
"Map",
3635
"MapLike",
3736
"MultiSet",
@@ -56,6 +55,17 @@
5655
]
5756

5857

58+
class BuiltinEvalError(Exception):
59+
"""
60+
Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
61+
62+
Try extracting this expression first.
63+
"""
64+
65+
def __str__(self) -> str:
66+
return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
67+
68+
5969
class Unit(BuiltinExpr, egg_sort="Unit"):
6070
"""
6171
The unit type. This is used to reprsent if a value exists in the e-graph or not.
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
7282
@method(preserve=True)
7383
def eval(self) -> str:
7484
value = _extract_lit(self)
75-
assert isinstance(value, bindings.String)
76-
return value.value
85+
assert isinstance(value, str)
86+
return value
7787

7888
def __init__(self, value: str) -> None: ...
7989

@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
97107
@method(preserve=True)
98108
def eval(self) -> bool:
99109
value = _extract_lit(self)
100-
assert isinstance(value, bindings.Bool)
101-
return value.value
110+
assert isinstance(value, bool)
111+
return value
102112

103113
@method(preserve=True)
104114
def __bool__(self) -> bool:
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
132142
@method(preserve=True)
133143
def eval(self) -> int:
134144
value = _extract_lit(self)
135-
assert isinstance(value, bindings.Int)
136-
return value.value
145+
assert isinstance(value, int)
146+
return value
137147

138148
@method(preserve=True)
139149
def __index__(self) -> int:
@@ -251,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
251261
@method(preserve=True)
252262
def eval(self) -> float:
253263
value = _extract_lit(self)
254-
assert isinstance(value, bindings.Float)
255-
return value.value
264+
assert isinstance(value, float)
265+
return value
256266

257267
@method(preserve=True)
258268
def __float__(self) -> float:
@@ -340,9 +350,12 @@ def eval(self) -> dict[T, V]:
340350
expr = cast("RuntimeExpr", self)
341351
d = {}
342352
while call.callable != ClassMethodRef("Map", "empty"):
343-
assert call.callable == MethodRef("Map", "insert")
353+
msg = "Map can only be evaluated if it is empty or a series of inserts."
354+
if call.callable != MethodRef("Map", "insert"):
355+
raise BuiltinEvalError(msg)
344356
call_typed, k_typed, v_typed = call.args
345-
assert isinstance(call_typed.expr, CallDecl)
357+
if not isinstance(call_typed.expr, CallDecl):
358+
raise BuiltinEvalError(msg)
346359
k = cast("T", expr.__with_expr__(k_typed))
347360
v = cast("V", expr.__with_expr__(v_typed))
348361
d[k] = v
@@ -404,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
404417
@method(preserve=True)
405418
def eval(self) -> set[T]:
406419
call = _extract_call(self)
407-
assert call.callable == InitRef("Set")
420+
if call.callable != InitRef("Set"):
421+
msg = "Set can only be initialized with the Set constructor."
422+
raise BuiltinEvalError(msg)
408423
return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
409424

410425
@method(preserve=True)
@@ -466,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
466481
@method(preserve=True)
467482
def eval(self) -> list[T]:
468483
call = _extract_call(self)
469-
assert call.callable == InitRef("MultiSet")
484+
if call.callable != InitRef("MultiSet"):
485+
msg = "MultiSet can only be initialized with the MultiSet constructor."
486+
raise BuiltinEvalError(msg)
470487
return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
471488

472489
@method(preserve=True)
@@ -513,11 +530,15 @@ class Rational(BuiltinExpr):
513530
@method(preserve=True)
514531
def eval(self) -> Fraction:
515532
call = _extract_call(self)
516-
assert call.callable == InitRef("Rational")
533+
if call.callable != InitRef("Rational"):
534+
msg = "Rational can only be initialized with the Rational constructor."
535+
raise BuiltinEvalError(msg)
517536

518537
def _to_int(e: TypedExprDecl) -> int:
519538
expr = e.expr
520-
assert isinstance(expr, LitDecl)
539+
if not isinstance(expr, LitDecl):
540+
msg = "Rational can only be initialized with literals"
541+
raise BuiltinEvalError(msg)
521542
assert isinstance(expr.value, int)
522543
return expr.value
523544

@@ -596,9 +617,13 @@ class BigInt(BuiltinExpr):
596617
@method(preserve=True)
597618
def eval(self) -> int:
598619
call = _extract_call(self)
599-
assert call.callable == ClassMethodRef("BigInt", "from_string")
620+
if call.callable != ClassMethodRef("BigInt", "from_string"):
621+
msg = "BigInt can only be initialized with the BigInt constructor."
622+
raise BuiltinEvalError(msg)
600623
(s,) = call.args
601-
assert isinstance(s.expr, LitDecl)
624+
if not isinstance(s.expr, LitDecl):
625+
msg = "BigInt can only be initialized with literals"
626+
raise BuiltinEvalError(msg)
602627
assert isinstance(s.expr.value, str)
603628
return int(s.expr.value)
604629

@@ -717,14 +742,19 @@ class BigRat(BuiltinExpr):
717742
@method(preserve=True)
718743
def eval(self) -> Fraction:
719744
call = _extract_call(self)
720-
assert call.callable == InitRef("BigRat")
745+
if call.callable != InitRef("BigRat"):
746+
msg = "BigRat can only be initialized with the BigRat constructor."
747+
raise BuiltinEvalError(msg)
721748

722749
def _to_fraction(e: TypedExprDecl) -> Fraction:
723750
expr = e.expr
724-
assert isinstance(expr, CallDecl)
725-
assert expr.callable == ClassMethodRef("BigInt", "from_string")
751+
if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
752+
msg = "BigRat can only be initialized BigInt strings"
753+
raise BuiltinEvalError(msg)
726754
(s,) = expr.args
727-
assert isinstance(s.expr, LitDecl)
755+
if not isinstance(s.expr, LitDecl):
756+
msg = "BigInt can only be initialized with literals"
757+
raise BuiltinEvalError(msg)
728758
assert isinstance(s.expr.value, str)
729759
return Fraction(s.expr.value)
730760

@@ -821,7 +851,10 @@ def eval(self) -> tuple[T, ...]:
821851
call = _extract_call(self)
822852
if call.callable == ClassMethodRef("Vec", "empty"):
823853
return ()
824-
assert call.callable == InitRef("Vec")
854+
855+
if call.callable != InitRef("Vec"):
856+
msg = "Vec can only be initialized with the Vec constructor."
857+
raise BuiltinEvalError(msg)
825858
return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
826859

827860
@method(preserve=True)
@@ -889,10 +922,11 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
889922
class PyObject(BuiltinExpr):
890923
@method(preserve=True)
891924
def eval(self) -> object:
892-
report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", self), 0)
893-
assert isinstance(report, bindings.Best)
894-
expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
895-
return GLOBAL_PY_OBJECT_SORT.load(expr)
925+
expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
926+
if not isinstance(expr, PyObjectDecl):
927+
msg = "PyObject can only be evaluated if it is a PyObject literal"
928+
raise BuiltinEvalError(msg)
929+
return expr.value
896930

897931
def __init__(self, value: object) -> None: ...
898932

@@ -1027,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
10271061
converter(FunctionType, UnstableFn, _convert_function)
10281062

10291063

1030-
def _extract_lit(e: BaseExpr) -> bindings._Literal:
1064+
def _extract_lit(e: BaseExpr) -> LitType:
10311065
"""
10321066
Special case extracting literals to make this faster by using termdag directly.
10331067
"""
1034-
report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", e), 0)
1035-
assert isinstance(report, bindings.Best)
1036-
term = report.term
1037-
assert isinstance(term, bindings.TermLit)
1038-
return term.value
1068+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1069+
if not isinstance(expr, LitDecl):
1070+
msg = "Expected a literal"
1071+
raise BuiltinEvalError(msg)
1072+
return expr.value
10391073

10401074

10411075
def _extract_call(e: BaseExpr) -> CallDecl:
10421076
"""
10431077
Extracts the call form of an expression
10441078
"""
1045-
extracted = cast("RuntimeExpr", (EGraph.current or EGraph()).extract(e))
1046-
expr = extracted.__egg_typed_expr__.expr
1047-
assert isinstance(expr, CallDecl)
1079+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1080+
if not isinstance(expr, CallDecl):
1081+
msg = "Expected a call expression"
1082+
raise BuiltinEvalError(msg)
10481083
return expr

python/egglog/egraph.py

Lines changed: 6 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
import pathlib
66
import tempfile
7-
from collections.abc import Callable, Generator, Iterable, Iterator
7+
from collections.abc import Callable, Generator, Iterable
88
from contextvars import ContextVar
99
from dataclasses import InitVar, dataclass, field
1010
from functools import partial
@@ -17,7 +17,6 @@
1717
Generic,
1818
Literal,
1919
Never,
20-
Protocol,
2120
TypeAlias,
2221
TypedDict,
2322
TypeVar,
@@ -85,7 +84,6 @@
8584
"set_",
8685
"simplify",
8786
"subsume",
88-
"try_evaling",
8987
"union",
9088
"unstable_combine_rulesets",
9189
"var",
@@ -847,7 +845,6 @@ class EGraph:
847845
Can run actions, check facts, run schedules, or extract minimal cost expressions.
848846
"""
849847

850-
current: ClassVar[EGraph | None] = None
851848
seminaive: InitVar[bool] = True
852849
save_egglog_string: InitVar[bool] = False
853850

@@ -1200,16 +1197,6 @@ def to_json() -> str:
12001197
if visualize:
12011198
VisualizerWidget(egraphs=egraphs).display_or_open()
12021199

1203-
@contextlib.contextmanager
1204-
def set_current(self) -> Iterator[None]:
1205-
"""
1206-
Context manager that will set the current egraph. It will be set back after.
1207-
"""
1208-
prev_current = EGraph.current
1209-
EGraph.current = self
1210-
yield
1211-
EGraph.current = prev_current
1212-
12131200
@property
12141201
def _egraph(self) -> bindings.EGraph:
12151202
return self._state.egraph
@@ -1303,8 +1290,6 @@ class Schedule(DelayedDeclerations):
13031290
A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
13041291
"""
13051292

1306-
current: ClassVar[Schedule | None] = None
1307-
13081293
# Defer declerations so that we can have rule generators that used not yet defined yet
13091294
schedule: ScheduleDecl
13101295

@@ -1332,16 +1317,6 @@ def __add__(self, other: Schedule) -> Schedule:
13321317
"""
13331318
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
13341319

1335-
@contextlib.contextmanager
1336-
def set_current(self) -> Iterator[None]:
1337-
"""
1338-
Context manager that will set the current schedule. It will be set back after
1339-
"""
1340-
prev_current = Schedule.current
1341-
Schedule.current = self
1342-
yield
1343-
Schedule.current = prev_current
1344-
13451320

13461321
@dataclass
13471322
class Ruleset(Schedule):
@@ -1488,9 +1463,12 @@ def __repr__(self) -> str:
14881463

14891464
def __bool__(self) -> bool:
14901465
"""
1491-
Returns True if the two sides of an equality are equal in the egraph or the expression is in the egraph.
1466+
Returns True if the two sides of an equality are structurally equal.
14921467
"""
1493-
return (EGraph.current or EGraph()).check_bool(self)
1468+
if not isinstance(self.fact, EqDecl):
1469+
msg = "Can only check equality facts"
1470+
raise TypeError(msg)
1471+
return self.fact.left == self.fact.right
14941472

14951473

14961474
@dataclass
@@ -1876,34 +1854,3 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
18761854
yield
18771855
finally:
18781856
_CURRENT_RULESET.reset(token)
1879-
1880-
1881-
T_co = TypeVar("T_co", covariant=True)
1882-
1883-
1884-
class _EvalsTo(Protocol, Generic[T_co]):
1885-
def eval(self) -> T_co: ...
1886-
1887-
1888-
def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T:
1889-
"""
1890-
Try evaling the expression that will result in a primitive expression being fill.
1891-
if it fails, display the egraph and raise an error.
1892-
"""
1893-
egraph = EGraph.current or EGraph()
1894-
with egraph.set_current():
1895-
try:
1896-
return prim_expr.eval()
1897-
except BaseException: # noqa: S110
1898-
pass
1899-
# If this primitive doesn't exist in the egraph, we need to try to create it by
1900-
# registering the expression and running the schedule
1901-
egraph.register(expr)
1902-
egraph.run(Schedule.current or schedule)
1903-
try:
1904-
with egraph.set_current():
1905-
return prim_expr.eval()
1906-
except BaseException as e:
1907-
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1908-
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1909-
raise

0 commit comments

Comments
 (0)