Skip to content

Commit c312526

Browse files
Merge pull request #282 from egraphs-good/less-magic
Remove automatic extraction in builtin evaluation
2 parents 640e8a1 + 3057cc8 commit c312526

15 files changed

+220
-178
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Change builtins to not evaluate values in egraph and changes facts to compare structural equality instead of using an egraph when converting to a boolean, removing magic context (`EGraph.current` and `Schedule.current`) that was added in release 9.0.0.
8+
- Fix bug that improperly upcasted values for ==
9+
710
## 9.0.1 (2025-03-20)
811

912
- Add missing i64.log2 method to the bindings

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _num_rule(a: Num, b: Num, c: Num, i: i64, j: i64):
4545
yield rewrite(Num(i) * Num(j)).to(Num(i * j))
4646
4747
egraph.saturate()
48-
egraph.check(eq(expr1).to(expr2))
48+
egraph.check(expr1 == expr2)
4949
egraph.extract(expr1)
5050
```
5151

python/egglog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from . import config, ipython_magic # noqa: F401
6+
from .bindings import EggSmolError # noqa: F401
67
from .builtins import * # noqa: UP029
78
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
89
from .egraph import *

python/egglog/builtins.py

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212

1313
from typing_extensions import TypeVarTuple, Unpack
1414

15-
from . import bindings
1615
from .conversion import convert, converter, get_type_args
1716
from .declarations import *
18-
from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
19-
from .egraph_state import GLOBAL_PY_OBJECT_SORT
17+
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
2018
from .functionalize import functionalize
2119
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
2220
from .thunk import Thunk
@@ -32,6 +30,7 @@
3230
"BigRatLike",
3331
"Bool",
3432
"BoolLike",
33+
"BuiltinEvalError",
3534
"Map",
3635
"MapLike",
3736
"MultiSet",
@@ -56,6 +55,17 @@
5655
]
5756

5857

58+
class BuiltinEvalError(Exception):
59+
"""
60+
Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
61+
62+
Try extracting this expression first.
63+
"""
64+
65+
def __str__(self) -> str:
66+
return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
67+
68+
5969
class Unit(BuiltinExpr, egg_sort="Unit"):
6070
"""
6171
The unit type. This is used to reprsent if a value exists in the e-graph or not.
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
7282
@method(preserve=True)
7383
def eval(self) -> str:
7484
value = _extract_lit(self)
75-
assert isinstance(value, bindings.String)
76-
return value.value
85+
assert isinstance(value, str)
86+
return value
7787

7888
def __init__(self, value: str) -> None: ...
7989

@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
97107
@method(preserve=True)
98108
def eval(self) -> bool:
99109
value = _extract_lit(self)
100-
assert isinstance(value, bindings.Bool)
101-
return value.value
110+
assert isinstance(value, bool)
111+
return value
102112

103113
@method(preserve=True)
104114
def __bool__(self) -> bool:
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
132142
@method(preserve=True)
133143
def eval(self) -> int:
134144
value = _extract_lit(self)
135-
assert isinstance(value, bindings.Int)
136-
return value.value
145+
assert isinstance(value, int)
146+
return value
137147

138148
@method(preserve=True)
139149
def __index__(self) -> int:
@@ -251,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
251261
@method(preserve=True)
252262
def eval(self) -> float:
253263
value = _extract_lit(self)
254-
assert isinstance(value, bindings.Float)
255-
return value.value
264+
assert isinstance(value, float)
265+
return value
256266

257267
@method(preserve=True)
258268
def __float__(self) -> float:
@@ -340,9 +350,12 @@ def eval(self) -> dict[T, V]:
340350
expr = cast("RuntimeExpr", self)
341351
d = {}
342352
while call.callable != ClassMethodRef("Map", "empty"):
343-
assert call.callable == MethodRef("Map", "insert")
353+
msg = "Map can only be evaluated if it is empty or a series of inserts."
354+
if call.callable != MethodRef("Map", "insert"):
355+
raise BuiltinEvalError(msg)
344356
call_typed, k_typed, v_typed = call.args
345-
assert isinstance(call_typed.expr, CallDecl)
357+
if not isinstance(call_typed.expr, CallDecl):
358+
raise BuiltinEvalError(msg)
346359
k = cast("T", expr.__with_expr__(k_typed))
347360
v = cast("V", expr.__with_expr__(v_typed))
348361
d[k] = v
@@ -404,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
404417
@method(preserve=True)
405418
def eval(self) -> set[T]:
406419
call = _extract_call(self)
407-
assert call.callable == InitRef("Set")
420+
if call.callable != InitRef("Set"):
421+
msg = "Set can only be initialized with the Set constructor."
422+
raise BuiltinEvalError(msg)
408423
return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
409424

410425
@method(preserve=True)
@@ -466,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
466481
@method(preserve=True)
467482
def eval(self) -> list[T]:
468483
call = _extract_call(self)
469-
assert call.callable == InitRef("MultiSet")
484+
if call.callable != InitRef("MultiSet"):
485+
msg = "MultiSet can only be initialized with the MultiSet constructor."
486+
raise BuiltinEvalError(msg)
470487
return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
471488

472489
@method(preserve=True)
@@ -513,11 +530,15 @@ class Rational(BuiltinExpr):
513530
@method(preserve=True)
514531
def eval(self) -> Fraction:
515532
call = _extract_call(self)
516-
assert call.callable == InitRef("Rational")
533+
if call.callable != InitRef("Rational"):
534+
msg = "Rational can only be initialized with the Rational constructor."
535+
raise BuiltinEvalError(msg)
517536

518537
def _to_int(e: TypedExprDecl) -> int:
519538
expr = e.expr
520-
assert isinstance(expr, LitDecl)
539+
if not isinstance(expr, LitDecl):
540+
msg = "Rational can only be initialized with literals"
541+
raise BuiltinEvalError(msg)
521542
assert isinstance(expr.value, int)
522543
return expr.value
523544

@@ -596,9 +617,13 @@ class BigInt(BuiltinExpr):
596617
@method(preserve=True)
597618
def eval(self) -> int:
598619
call = _extract_call(self)
599-
assert call.callable == ClassMethodRef("BigInt", "from_string")
620+
if call.callable != ClassMethodRef("BigInt", "from_string"):
621+
msg = "BigInt can only be initialized with the BigInt constructor."
622+
raise BuiltinEvalError(msg)
600623
(s,) = call.args
601-
assert isinstance(s.expr, LitDecl)
624+
if not isinstance(s.expr, LitDecl):
625+
msg = "BigInt can only be initialized with literals"
626+
raise BuiltinEvalError(msg)
602627
assert isinstance(s.expr.value, str)
603628
return int(s.expr.value)
604629

@@ -717,14 +742,19 @@ class BigRat(BuiltinExpr):
717742
@method(preserve=True)
718743
def eval(self) -> Fraction:
719744
call = _extract_call(self)
720-
assert call.callable == InitRef("BigRat")
745+
if call.callable != InitRef("BigRat"):
746+
msg = "BigRat can only be initialized with the BigRat constructor."
747+
raise BuiltinEvalError(msg)
721748

722749
def _to_fraction(e: TypedExprDecl) -> Fraction:
723750
expr = e.expr
724-
assert isinstance(expr, CallDecl)
725-
assert expr.callable == ClassMethodRef("BigInt", "from_string")
751+
if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
752+
msg = "BigRat can only be initialized BigInt strings"
753+
raise BuiltinEvalError(msg)
726754
(s,) = expr.args
727-
assert isinstance(s.expr, LitDecl)
755+
if not isinstance(s.expr, LitDecl):
756+
msg = "BigInt can only be initialized with literals"
757+
raise BuiltinEvalError(msg)
728758
assert isinstance(s.expr.value, str)
729759
return Fraction(s.expr.value)
730760

@@ -821,7 +851,10 @@ def eval(self) -> tuple[T, ...]:
821851
call = _extract_call(self)
822852
if call.callable == ClassMethodRef("Vec", "empty"):
823853
return ()
824-
assert call.callable == InitRef("Vec")
854+
855+
if call.callable != InitRef("Vec"):
856+
msg = "Vec can only be initialized with the Vec constructor."
857+
raise BuiltinEvalError(msg)
825858
return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
826859

827860
@method(preserve=True)
@@ -889,10 +922,11 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
889922
class PyObject(BuiltinExpr):
890923
@method(preserve=True)
891924
def eval(self) -> object:
892-
report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", self), 0)
893-
assert isinstance(report, bindings.Best)
894-
expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
895-
return GLOBAL_PY_OBJECT_SORT.load(expr)
925+
expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
926+
if not isinstance(expr, PyObjectDecl):
927+
msg = "PyObject can only be evaluated if it is a PyObject literal"
928+
raise BuiltinEvalError(msg)
929+
return expr.value
896930

897931
def __init__(self, value: object) -> None: ...
898932

@@ -1027,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
10271061
converter(FunctionType, UnstableFn, _convert_function)
10281062

10291063

1030-
def _extract_lit(e: BaseExpr) -> bindings._Literal:
1064+
def _extract_lit(e: BaseExpr) -> LitType:
10311065
"""
10321066
Special case extracting literals to make this faster by using termdag directly.
10331067
"""
1034-
report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", e), 0)
1035-
assert isinstance(report, bindings.Best)
1036-
term = report.term
1037-
assert isinstance(term, bindings.TermLit)
1038-
return term.value
1068+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1069+
if not isinstance(expr, LitDecl):
1070+
msg = "Expected a literal"
1071+
raise BuiltinEvalError(msg)
1072+
return expr.value
10391073

10401074

10411075
def _extract_call(e: BaseExpr) -> CallDecl:
10421076
"""
10431077
Extracts the call form of an expression
10441078
"""
1045-
extracted = cast("RuntimeExpr", (EGraph.current or EGraph()).extract(e))
1046-
expr = extracted.__egg_typed_expr__.expr
1047-
assert isinstance(expr, CallDecl)
1079+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1080+
if not isinstance(expr, CallDecl):
1081+
msg = "Expected a call expression"
1082+
raise BuiltinEvalError(msg)
10481083
return expr

python/egglog/conversion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
149149
decls = _retrieve_conversion_decls()
150150
a_tp = _get_tp(a)
151151
b_tp = _get_tp(b)
152+
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153+
if not (
154+
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155+
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156+
):
157+
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
152158
a_converts_to = {
153159
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
154160
}

0 commit comments

Comments
 (0)