Skip to content

Commit 84499b9

Browse files
Add support for binary operations running on preserved operations
1 parent 8651b1a commit 84499b9

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

python/egglog/runtime.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -793,14 +793,34 @@ def _numeric_binary_method(
793793
Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
794794
the LHS or the RHS as exactly the right type and then upcasting the other to that type.
795795
"""
796+
from .conversion import ( # noqa: PLC0415
797+
ConvertError,
798+
convert_to_same_type,
799+
min_binary_conversion,
800+
resolve_type,
801+
)
802+
803+
# 1. switch if reversed method
804+
if r_method:
805+
self, other = other, self
796806
# First check if we have a preserved method for this:
797807
if isinstance(self, RuntimeExpr) and (
798-
(preserved_method := self.__egg_class_decl__.preserved_methods.get(method_name)) is not None
808+
(preserved_method := self.__egg_class_decl__.preserved_methods.get(name)) is not None
799809
):
800810
return preserved_method.__get__(self)(other)
801-
# 1. switch if reversed method
802-
if r_method:
803-
self, other = other, self
811+
# Then check if the self is a Python type and the other side has a preserved method
812+
if (
813+
not isinstance(self, RuntimeExpr)
814+
and isinstance(other, RuntimeExpr)
815+
and ((preserved_method := other.__egg_class_decl__.preserved_methods.get(name)) is not None)
816+
):
817+
try:
818+
new_self = convert_to_same_type(self, other)
819+
except ConvertError:
820+
pass
821+
else:
822+
return preserved_method.__get__(new_self)(other)
823+
804824
# If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion
805825
if not (
806826
isinstance(self, RuntimeExpr)
@@ -811,8 +831,6 @@ def _numeric_binary_method(
811831
)
812832
)
813833
):
814-
from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415
815-
816834
best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
817835

818836
if not best_method:

python/tests/test_high_level.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,3 +1552,19 @@ def __radd__(self, other: XLike) -> X: ...
15521552

15531553
assert X(1) + 2 == X(1) + X(2)
15541554
assert 2 + X(1) == X(2) + X(1)
1555+
1556+
1557+
def test_binary_preserved():
1558+
class X(Expr):
1559+
def __init__(self, value: i64Like) -> None: ...
1560+
1561+
@method(preserve=True)
1562+
def __add__(self, other: T) -> tuple[X, T]:
1563+
return (self, other)
1564+
1565+
def __radd__(self, other: object) -> tuple[X, X]: ...
1566+
1567+
converter(i64, X, X)
1568+
1569+
assert X(1) + 10 == (X(1), 10)
1570+
assert 10 + X(1) == (X(10), X(1))

0 commit comments

Comments
 (0)