Skip to content

Commit 3057cc8

Browse files
Fix bug around improperly upcasting during ==
1 parent fddc2dc commit 3057cc8

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ _This project uses semantic versioning_
55
## UNRELEASED
66

77
- Change builtins to not evaluate values in egraph and changes facts to compare structural equality instead of using an egraph when converting to a boolean, removing magic context (`EGraph.current` and `Schedule.current`) that was added in release 9.0.0.
8+
- Fix bug that improperly upcasted values for ==
89

910
## 9.0.1 (2025-03-20)
1011

python/egglog/conversion.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
149149
decls = _retrieve_conversion_decls()
150150
a_tp = _get_tp(a)
151151
b_tp = _get_tp(b)
152+
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153+
if not (
154+
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155+
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156+
):
157+
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
152158
a_converts_to = {
153159
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
154160
}

python/tests/test_high_level.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,25 @@ def test_eq_false(self):
839839
assert not (i64(3) == 4) # noqa: SIM201
840840

841841

842+
def test_no_upcast_eq():
843+
"""
844+
Verifies that if two items can be upcast to something, calling == on them won't use
845+
equality
846+
"""
847+
848+
class A(Expr):
849+
def __init__(self) -> None: ...
850+
851+
class B(Expr):
852+
def __init__(self) -> None: ...
853+
def __eq__(self, other: B) -> B: ... # type: ignore[override]
854+
855+
converter(A, B, lambda a: B())
856+
857+
assert isinstance(A() == A(), Fact)
858+
assert not isinstance(B() == B(), Fact)
859+
860+
842861
EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))
843862

844863

0 commit comments

Comments
 (0)