Skip to content

Commit f214e23

Browse files
Try moving binary upcasting to conversion file
1 parent 2691387 commit f214e23

File tree

3 files changed

+40
-66
lines changed

3 files changed

+40
-66
lines changed

python/egglog/conversion.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -143,38 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
143143
return tp
144144

145145

146-
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
147-
# """
148-
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
149-
# """
150-
# decls = _retrieve_conversion_decls().copy()
151-
# if isinstance(a, RuntimeExpr):
152-
# decls |= a
153-
# if isinstance(b, RuntimeExpr):
154-
# decls |= b
155-
156-
# a_tp = _get_tp(a)
157-
# b_tp = _get_tp(b)
158-
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
159-
# if not (
160-
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
161-
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
162-
# ):
163-
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
164-
# a_converts_to = {
165-
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
166-
# }
167-
# b_converts_to = {
168-
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
169-
# }
170-
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
171-
# a_converts_to[a_tp] = 0
172-
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
173-
# b_converts_to[b_tp] = 0
174-
# common = set(a_converts_to) & set(b_converts_to)
175-
# if not common:
176-
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
177-
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
146+
def min_binary_conversion(
147+
method_name: str, lhs: type | JustTypeRef, rhs: type | JustTypeRef
148+
) -> tuple[Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None:
149+
"""
150+
Given a binary method and two starting types for the LHS and RHS, return a pair of callable which will convert
151+
the LHS and RHS to appropriate types which support this method. If no such conversion is possible, return None.
152+
153+
It should return the types which minimize the total conversion cost. If one of the types is a Python type, then
154+
both of them can be converted. However, if both are egglog types, then only one of them can be converted.
155+
"""
156+
decls = retrieve_conversion_decls()
157+
# tuple of (cost, convert_self)
158+
best_method: tuple[int, Callable[[Any], RuntimeExpr]] | None = None
159+
# Start by checking if we have a LHS that matches exactly and a RHS which can be converted
160+
if (
161+
isinstance(lhs, JustTypeRef)
162+
and (desired_other_type := decls.check_binary_method_with_self_type(method_name, lhs))
163+
and (converter := CONVERSIONS.get((rhs, desired_other_type)))
164+
):
165+
best_method = (converter[0], lambda x: x)
166+
167+
# Next see if it's possible to convert the LHS and keep the RHS as is
168+
if isinstance(rhs, JustTypeRef):
169+
decls = retrieve_conversion_decls()
170+
for desired_self_type in decls.check_binary_method_with_other_type(method_name, rhs):
171+
if converter := CONVERSIONS.get((lhs, desired_self_type)):
172+
cost, convert_self = converter
173+
if best_method is None or best_method[0] > cost:
174+
best_method = (cost, convert_self)
175+
if best_method is None:
176+
return None
177+
return best_method[1], best_method[1]
178178

179179

180180
def identity(x: object) -> object:

python/egglog/declarations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,10 @@ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTy
244244
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
245245
"""
246246
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
247-
if callable_decl := self._classes[self_type.name].methods.get(method_name):
247+
class_decl = self._classes.get(self_type.name)
248+
if class_decl is None:
249+
return None
250+
if callable_decl := class_decl.methods.get(method_name):
248251
match callable_decl.signature:
249252
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
250253
return other_arg_type.to_just(vars)

python/egglog/runtime.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -649,42 +649,13 @@ def _numeric_binary_method(self: object, other: object, name: str = name, r_meth
649649
)
650650
)
651651
):
652-
from .conversion import CONVERSIONS, resolve_type, retrieve_conversion_decls # noqa: PLC0415
653-
654-
# tuple of (cost, convert_self)
655-
best_method: (
656-
tuple[
657-
int,
658-
Callable[[Any], RuntimeExpr],
659-
]
660-
| None
661-
) = None
662-
# Start by checking if we have a LHS that matches exactly and a RHS which can be converted
663-
if (
664-
isinstance(self, RuntimeExpr)
665-
and (
666-
desired_other_type := self.__egg_decls__.check_binary_method_with_self_type(
667-
name, self.__egg_typed_expr__.tp
668-
)
669-
)
670-
and (converter := CONVERSIONS.get((resolve_type(other), desired_other_type)))
671-
):
672-
best_method = (converter[0], lambda x: x)
673-
674-
# Next see if it's possible to convert the LHS and keep the RHS as is
675-
if isinstance(other, RuntimeExpr):
676-
decls = retrieve_conversion_decls()
677-
other_type = other.__egg_typed_expr__.tp
678-
resolved_self_type = resolve_type(self)
679-
for desired_self_type in decls.check_binary_method_with_other_type(name, other_type):
680-
if converter := CONVERSIONS.get((resolved_self_type, desired_self_type)):
681-
cost, convert_self = converter
682-
if best_method is None or best_method[0] > cost:
683-
best_method = (cost, convert_self)
652+
from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415
653+
654+
best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
684655

685656
if not best_method:
686657
raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
687-
self = best_method[1](self)
658+
self = best_method[0](self)
688659

689660
method_ref = MethodRef(self.__egg_class_name__, name)
690661
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)

0 commit comments

Comments
 (0)