Commit bbb4df8

Merge pull request #320 from egraphs-good/upgrade-python
Change conversion between binary operators to consider converting both types
2 parents 5f8da71 + b4bbe65 commit bbb4df8

17 files changed: +2156 -1492 lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
 
 ## UNRELEASED
 
+- Change conversion between binary operators to consider converting both types [#320](https://github.com/egraphs-good/egglog-python/pull/320)
 - Add ability to parse egglog expressions into Python values [#319](https://github.com/egraphs-good/egglog-python/pull/319)
 - Deprecates `.eval()` method on primitives in favor of `.value` which can be used with pattern matching.
 - Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315)

docs/reference/python-integration.md

Lines changed: 5 additions & 10 deletions
@@ -379,13 +379,13 @@ instead of the normal mechanism which relies on `__getattr__`, you can call `egg
 with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
 method is defined instead of just attempting to call it.
 
-### Reflected methods
+### Binary Method Conversions
 
-Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.
+For [rich comparison methods](https://docs.python.org/3/reference/datamodel.html#object.__lt__) (like `__lt__`, `__le__`, `__eq__`, etc.) and [binary numeric methods](https://docs.python.org/3/reference/datamodel.html#object.__add__) (like `__add__`, `__sub__`, etc.), some more advanced conversion logic is needed to ensure they are converted properly. We add the `__r<name>__` methods for all expressions so that we can handle either position they are placed in.
 
-Instead, whenever a reflected method is called, we will try to find the corresponding non-reflected method and call that instead.
-
-Also, if a normal method fails because the arguments cannot be converted to the right types, the reflected version of the second arg will be tried.
+If we have two values `lhs` and `rhs`, we will try to find the minimum cost conversion for both of them, and then call the method on the converted values.
+If both are expression instances, we will convert at most one of them. However, if one is an expression and the other
+is a different Python value (like an `int`), we will consider all possible conversions of both arguments to find the minimum.
 
 ```{code-cell} python
 class Int(Expr):
@@ -423,11 +423,6 @@ converter(Int, Float, Float.from_int)
 assert str(-1.0 + Int.var("x")) == "Float(-1.0) + Float.from_int(Int.var(\"x\"))"
 ```
 
-For methods which allow returning `NotImplemented`, i.e. the comparison + binary math methods, we will also try upcasting both
-types to the type which is lowest cost to convert both to.
-
-For example, if you have `Float` and `Int` wrapper types and you have write the expr `-1.0 + Int.var("x")` you might want the result to be `Float(-1.0) + Float.from_int(Int.var("x"))`:
-
 ### Mutating arguments
 
 In order to support Python functions and methods which mutate their arguments, you can pass in the `mutate_first_arg` keyword argument to the `@function` decorator and the `mutates_self` argument to the `@method` decorator. This will cause the first argument to be mutated in place, instead of being copied.

python/egglog/conversion.py

Lines changed: 47 additions & 35 deletions
@@ -143,41 +143,53 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
     return tp
 
 
-# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
-#     """
-#     Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
-#     """
-#     decls = _retrieve_conversion_decls().copy()
-#     if isinstance(a, RuntimeExpr):
-#         decls |= a
-#     if isinstance(b, RuntimeExpr):
-#         decls |= b
-
-#     a_tp = _get_tp(a)
-#     b_tp = _get_tp(b)
-#     # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
-#     if not (
-#         (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
-#         or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
-#     ):
-#         raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
-#     a_converts_to = {
-#         to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
-#     }
-#     b_converts_to = {
-#         to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
-#     }
-#     if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
-#         a_converts_to[a_tp] = 0
-#     if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
-#         b_converts_to[b_tp] = 0
-#     common = set(a_converts_to) & set(b_converts_to)
-#     if not common:
-#         raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
-#     return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
-
-
-def identity(x: object) -> object:
+def min_binary_conversion(
+    method_name: str, lhs: type | JustTypeRef, rhs: type | JustTypeRef
+) -> tuple[Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None:
+    """
+    Given a binary method and two starting types for the LHS and RHS, return a pair of callables which will convert
+    the LHS and RHS to appropriate types which support this method. If no such conversion is possible, return None.
+
+    It should return the types which minimize the total conversion cost. If one of the types is a Python type, then
+    both of them can be converted. However, if both are egglog types, then only one of them can be converted.
+    """
+    decls = retrieve_conversion_decls()
+    # tuple of (cost, convert lhs, convert rhs)
+    best_method: tuple[int, Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None = None
+
+    possible_lhs = _all_conversions_from(lhs) if isinstance(lhs, type) else [(0, lhs, identity)]
+    possible_rhs = _all_conversions_from(rhs) if isinstance(rhs, type) else [(0, rhs, identity)]
+    for lhs_cost, lhs_converted_type, lhs_convert in possible_lhs:
+        # Start by checking if we have a LHS that matches exactly and a RHS which can be converted
+        if (desired_other_type := decls.check_binary_method_with_self_type(method_name, lhs_converted_type)) and (
+            converter := CONVERSIONS.get((rhs, desired_other_type))
+        ):
+            cost = lhs_cost + converter[0]
+            if best_method is None or best_method[0] > cost:
+                best_method = (cost, lhs_convert, converter[1])
+
+    for rhs_cost, rhs_converted_type, rhs_convert in possible_rhs:
+        # Next see if it's possible to convert the LHS and keep the RHS as is
+        for desired_self_type in decls.check_binary_method_with_other_type(method_name, rhs_converted_type):
+            if converter := CONVERSIONS.get((lhs, desired_self_type)):
+                cost = rhs_cost + converter[0]
+                if best_method is None or best_method[0] > cost:
+                    best_method = (cost, converter[1], rhs_convert)
+    if best_method is None:
+        return None
+    return best_method[1], best_method[2]
+
+
+def _all_conversions_from(tp: JustTypeRef | type) -> list[tuple[int, JustTypeRef, Callable[[Any], RuntimeExpr]]]:
+    """
+    Get all conversions from a type to other types.
+
+    Returns a list of tuples of (cost, target type, conversion function).
+    """
+    return [(cost, target, fn) for (source, target), (cost, fn) in CONVERSIONS.items() if source == tp]
+
+
+def identity(x: Any) -> Any:
     return x
 
 
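Reviewer note: the two-pass search in `min_binary_conversion` can be hard to follow inside the diff, so here is a minimal, self-contained sketch of the same idea. It does not use egglog: the `CONVERSIONS` and `BINARY_METHODS` tables below are hypothetical stand-ins (type names are plain strings, converted values are tuples), and the lookup against `BINARY_METHODS` only approximates what `check_binary_method_with_self_type` and `check_binary_method_with_other_type` do in the real code.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any

Converter = Callable[[Any], Any]

# Hypothetical conversion table: (source, target) -> (cost, converter).
# Python types are keyed by the type object, toy "egglog" types by name.
CONVERSIONS: dict[tuple[Any, str], tuple[int, Converter]] = {
    (int, "Int"): (1, lambda x: ("Int", x)),
    (float, "Float"): (1, lambda x: ("Float", x)),
    ("Int", "Float"): (10, lambda x: ("Float.from_int", x)),
}

# Hypothetical method table: (method, self type) -> required type of the other argument.
BINARY_METHODS: dict[tuple[str, str], str] = {
    ("__add__", "Int"): "Int",
    ("__add__", "Float"): "Float",
}


def identity(x: Any) -> Any:
    return x


def all_conversions_from(tp: Any) -> list[tuple[int, str, Converter]]:
    """List (cost, target type, converter) for every conversion starting at tp."""
    return [(cost, target, fn) for (source, target), (cost, fn) in CONVERSIONS.items() if source == tp]


def min_binary_conversion(method_name: str, lhs: Any, rhs: Any) -> tuple[Converter, Converter] | None:
    """Return the cheapest (convert_lhs, convert_rhs) pair, or None if nothing fits."""
    best: tuple[int, Converter, Converter] | None = None
    # A Python type may be converted to any reachable toy type; a toy type stays as it is.
    possible_lhs = all_conversions_from(lhs) if isinstance(lhs, type) else [(0, lhs, identity)]
    possible_rhs = all_conversions_from(rhs) if isinstance(rhs, type) else [(0, rhs, identity)]

    # Pass 1: fix the (possibly converted) LHS, then convert the original RHS to what it expects.
    for lhs_cost, lhs_tp, lhs_convert in possible_lhs:
        other_tp = BINARY_METHODS.get((method_name, lhs_tp))
        converter = CONVERSIONS.get((rhs, other_tp)) if other_tp is not None else None
        if converter is not None:
            cost = lhs_cost + converter[0]
            if best is None or cost < best[0]:
                best = (cost, lhs_convert, converter[1])

    # Pass 2: fix the (possibly converted) RHS, then convert the original LHS to a matching self type.
    for rhs_cost, rhs_tp, rhs_convert in possible_rhs:
        for (name, self_tp), other_tp in BINARY_METHODS.items():
            if name != method_name or other_tp != rhs_tp:
                continue
            converter = CONVERSIONS.get((lhs, self_tp))
            if converter is not None:
                cost = rhs_cost + converter[0]
                if best is None or cost < best[0]:
                    best = (cost, converter[1], rhs_convert)

    return None if best is None else (best[1], best[2])
```

With these toy tables, `min_binary_conversion("__add__", float, "Int")` picks the pair that lifts the Python float to `Float` and converts the `Int` operand through `Float.from_int`, which mirrors the `-1.0 + Int.var("x")` example in the documentation hunk. As in the diff, the case where both operands already have the same type is assumed to be handled before this function is reached.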

python/egglog/declarations.py

Lines changed: 4 additions & 1 deletion
@@ -244,7 +244,10 @@ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTy
         Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
         """
         vars: dict[ClassTypeVarRef, JustTypeRef] = {}
-        if callable_decl := self._classes[self_type.name].methods.get(method_name):
+        class_decl = self._classes.get(self_type.name)
+        if class_decl is None:
+            return None
+        if callable_decl := class_decl.methods.get(method_name):
             match callable_decl.signature:
                 case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
                     return other_arg_type.to_just(vars)

python/egglog/egraph.py

Lines changed: 0 additions & 1 deletion
@@ -110,7 +110,6 @@
     "__weakref__",
     "__orig_bases__",
     "__annotations__",
-    "__hash__",
     "__qualname__",
     "__firstlineno__",
     "__static_attributes__",

python/egglog/exp/array_api.py

Lines changed: 36 additions & 2 deletions
@@ -154,6 +154,18 @@ def __le__(self, other: IntLike) -> Boolean: ...
     def __eq__(self, other: IntLike) -> Boolean:  # type: ignore[override]
         ...
 
+    # add a hash so that this test can pass
+    # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
+    @method(preserve=True)
+    def __hash__(self) -> int:
+        egraph = _get_current_egraph()
+        egraph.register(self)
+        egraph.run(array_api_schedule)
+        simplified = egraph.extract(self)
+        return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)
+
+    def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...
+
     # TODO: Fix this?
     # Make != always return a Bool, so that numpy.unique works on a tuple of ints
     # In _unique1d
@@ -280,6 +292,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
     yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
     yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
 
+    yield rewrite(o.__round__(OptionalInt.none)).to(o)
+
     # Never cannot be equal to anything real
     yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
 
@@ -354,8 +368,14 @@ def __add__(self, other: FloatLike) -> Float: ...
     def __sub__(self, other: FloatLike) -> Float: ...
 
     def __pow__(self, other: FloatLike) -> Float: ...
+    def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...
 
     def __eq__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
+    def __ne__(self, other: FloatLike) -> Boolean: ...  # type: ignore[override]
+    def __lt__(self, other: FloatLike) -> Boolean: ...
+    def __le__(self, other: FloatLike) -> Boolean: ...
+    def __gt__(self, other: FloatLike) -> Boolean: ...
+    def __ge__(self, other: FloatLike) -> Boolean: ...
 
 
 converter(float, Float, lambda x: Float(x))
@@ -366,9 +386,10 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
 
 
 @array_api_ruleset.register
-def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
+def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
     return [
         rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
+        rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
         rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
         rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
         # Convert from float to rational, if it's a whole number i.e. can be converted to int
@@ -383,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
         rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
         rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
         rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
-        # ==
+        # comparisons
         rewrite(Float(f) == Float(f)).to(TRUE),
         rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
+        rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
+        rewrite(Float(f) != Float(f)).to(FALSE),
+        rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
+        rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
+        rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
+        rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
+        rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
+        rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
+        rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
         rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
         rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
+        # round
+        rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
     ]
 
 
@@ -671,6 +703,8 @@ class OptionalInt(Expr, ruleset=array_api_ruleset):
     def some(cls, value: Int) -> OptionalInt: ...
 
 
+OptionalIntLike: TypeAlias = OptionalInt | IntLike | None
+
 converter(type(None), OptionalInt, lambda _: OptionalInt.none)
 converter(Int, OptionalInt, OptionalInt.some)
 

python/egglog/exp/array_api_jit.py

Lines changed: 11 additions & 5 deletions
@@ -14,16 +14,22 @@
 X = TypeVar("X", bound=Callable)
 
 
-def jit(fn: X) -> X:
+def jit(
+    fn: X,
+    *,
+    handle_expr: Callable[[NDArray], None] | None = None,
+    handle_optimized_expr: Callable[[NDArray], None] | None = None,
+) -> X:
     """
     Jit compiles a function
     """
     egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
+    if handle_expr:
+        handle_expr(res)
+    if handle_optimized_expr:
+        handle_optimized_expr(res_optimized)
     fn_program = EvalProgram(program, {"np": np})
-    fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
-    fn.initial_expr = res  # type: ignore[attr-defined]
-    fn.expr = res_optimized  # type: ignore[attr-defined]
-    return fn
+    return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
 
 
 def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
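Reviewer note: this hunk replaces the `fn.initial_expr` / `fn.expr` attributes with keyword-only callbacks. The sketch below illustrates that calling convention in isolation. `toy_jit` is a hypothetical stand-in that does no tracing or compilation, and its "expressions" are plain strings instead of `NDArray` values.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


def toy_jit(
    fn: Callable[..., Any],
    *,
    handle_expr: Callable[[str], None] | None = None,
    handle_optimized_expr: Callable[[str], None] | None = None,
) -> Callable[..., Any]:
    """Hypothetical stand-in: report intermediate results through callbacks."""
    expr = f"trace({fn.__name__})"    # placeholder for the traced expression
    optimized = f"optimized({expr})"  # placeholder for the simplified expression
    if handle_expr:
        handle_expr(expr)
    if handle_optimized_expr:
        handle_optimized_expr(optimized)
    # Nothing is attached to the returned function; callers observe via the callbacks.
    return fn


def double(x: int) -> int:
    return x * 2


seen: list[str] = []
compiled = toy_jit(double, handle_expr=seen.append, handle_optimized_expr=seen.append)
```

Code that previously read the attributes off the compiled function would instead pass a callback, for example a list's `append` as above, and inspect what was collected afterwards.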

python/egglog/exp/array_api_program_gen.py

Lines changed: 1 addition & 1 deletion
@@ -505,6 +505,6 @@ def bin_op(res: NDArray, op: str) -> Command:
     yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())
 
     # asarray
-    yield rewrite(ndarray_program(asarray(x, odtype))).to(
+    yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
         Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
     )

python/egglog/pretty.py

Lines changed: 3 additions & 0 deletions
@@ -394,6 +394,7 @@ def _call_inner( # noqa: C901, PLR0911, PLR0912
                 return f"{tp_ref}.{method_name}", args
             case MethodRef(_class_name, method_name):
                 slf, *args = args
+                non_str_slf = slf
                 slf = self(slf, parens=True)
                 match method_name:
                     case _ if method_name in UNARY_METHODS:
@@ -410,6 +411,8 @@ def _call_inner( # noqa: C901, PLR0911, PLR0912
                         return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
                     case "__setitem__":
                         return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
+                    case "__round__":
+                        return "round", [non_str_slf, *args]
                     case _:
                         return f"{slf}.{method_name}", args
             case ConstantRef(name):
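Reviewer note: the new `__round__` case keeps the un-stringified receiver so that it is printed as an ordinary argument of `round(...)` and not as a parenthesized method receiver. A minimal sketch of that distinction follows; `pretty_method_call` and its `slf_is_compound` flag are hypothetical simplifications, not the pretty printer's real interface.

```python
def pretty_method_call(method_name: str, slf: str, args: list[str], slf_is_compound: bool = False) -> str:
    """Render a method call on an already-rendered receiver string."""
    if method_name == "__round__":
        # The receiver becomes a plain function argument, so no parentheses are needed.
        return f"round({', '.join([slf, *args])})"
    # Attribute access binds tighter than operators, so a compound receiver needs parentheses.
    receiver = f"({slf})" if slf_is_compound else slf
    return f"{receiver}.{method_name}({', '.join(args)})"
```

For a compound receiver such as `a + b`, this renders `round(a + b)` for `__round__` but `(a + b).abs()` for an ordinary method.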

python/egglog/runtime.py

Lines changed: 11 additions & 37 deletions
@@ -65,6 +65,10 @@
     "__and__",
     "__xor__",
     "__or__",
+    "__lt__",
+    "__le__",
+    "__gt__",
+    "__ge__",
 }
 
 
@@ -85,10 +89,7 @@
     "__pos__",
     "__neg__",
     "__invert__",
-    "__lt__",
-    "__le__",
-    "__gt__",
-    "__ge__",
+    "__round__",
 }
 
 # Set this globally so we can get access to PyObject when we have a type annotation of just object.
@@ -567,6 +568,8 @@ def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
         self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
 
     def __hash__(self) -> int:
+        if (method := _get_expr_method(self, "__hash__")) is not None:
+            return cast("int", cast("Any", method()))
         return hash(self.__egg_typed_expr__)
 
     # Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
@@ -647,42 +650,13 @@ def _numeric_binary_method(self: object, other: object, name: str = name, r_meth
                 )
             )
         ):
-            from .conversion import CONVERSIONS, resolve_type, retrieve_conversion_decls  # noqa: PLC0415
-
-            # tuple of (cost, convert_self)
-            best_method: (
-                tuple[
-                    int,
-                    Callable[[Any], RuntimeExpr],
-                ]
-                | None
-            ) = None
-            # Start by checking if we have a LHS that matches exactly and a RHS which can be converted
-            if (
-                isinstance(self, RuntimeExpr)
-                and (
-                    desired_other_type := self.__egg_decls__.check_binary_method_with_self_type(
-                        name, self.__egg_typed_expr__.tp
-                    )
-                )
-                and (converter := CONVERSIONS.get((resolve_type(other), desired_other_type)))
-            ):
-                best_method = (converter[0], lambda x: x)
-
-            # Next see if it's possible to convert the LHS and keep the RHS as is
-            if isinstance(other, RuntimeExpr):
-                decls = retrieve_conversion_decls()
-                other_type = other.__egg_typed_expr__.tp
-                resolved_self_type = resolve_type(self)
-                for desired_self_type in decls.check_binary_method_with_other_type(name, other_type):
-                    if converter := CONVERSIONS.get((resolved_self_type, desired_self_type)):
-                        cost, convert_self = converter
-                        if best_method is None or best_method[0] > cost:
-                            best_method = (cost, convert_self)
+            from .conversion import min_binary_conversion, resolve_type  # noqa: PLC0415
+
+            best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
 
             if not best_method:
                 raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
-            self = best_method[1](self)
+            self, other = best_method[0](self), best_method[1](other)
 
         method_ref = MethodRef(self.__egg_class_name__, name)
         fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
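Reviewer note: after this change, both the forward and the reflected operator route through one resolver that may convert either operand. The sketch below shows that wiring on a toy `Expr` class. Everything in it is hypothetical, and it uses a deliberately simplified rule (upcast both sides to `Float` whenever the types differ) in place of the cost-based search in `min_binary_conversion`.

```python
from __future__ import annotations

from typing import Any


class Expr:
    """Toy expression: a type name plus its rendered text."""

    def __init__(self, tp: str, text: str) -> None:
        self.tp = tp
        self.text = text

    def __str__(self) -> str:
        return self.text


def to_expr(value: Any) -> Expr:
    """Lift a Python literal into a toy expression; expressions pass through."""
    if isinstance(value, Expr):
        return value
    if isinstance(value, bool):
        raise TypeError("bool is not supported in this sketch")
    if isinstance(value, int):
        return Expr("Int", f"Int({value})")
    if isinstance(value, float):
        return Expr("Float", f"Float({value})")
    raise TypeError(f"no conversion for {type(value).__name__}")


def to_float(e: Expr) -> Expr:
    return e if e.tp == "Float" else Expr("Float", f"Float.from_int({e.text})")


def resolve_binary(symbol: str, lhs: Any, rhs: Any) -> Expr:
    """Convert both operands to a common type, then build the combined expression."""
    left, right = to_expr(lhs), to_expr(rhs)
    if left.tp != right.tp:
        left, right = to_float(left), to_float(right)
    return Expr(left.tp, f"{left} {symbol} {right}")


def install(name: str, symbol: str) -> None:
    """Add both __<name>__ and __r<name>__ so either operand position works."""

    def forward(self: Expr, other: Any) -> Expr:
        return resolve_binary(symbol, self, other)

    def reflected(self: Expr, other: Any) -> Expr:
        # Python calls the reflected method with the operands swapped.
        return resolve_binary(symbol, other, self)

    setattr(Expr, f"__{name}__", forward)
    setattr(Expr, f"__r{name}__", reflected)


for _name, _symbol in [("add", "+"), ("sub", "-"), ("mul", "*")]:
    install(_name, _symbol)

x = Expr("Int", 'Int.var("x")')
```

Here `-1.0 + x` reaches `__radd__` because `float.__add__` returns `NotImplemented` for an `Expr`, and the shared resolver then converts both sides, matching the shape of the assertion in the documentation hunk.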
