Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _This project uses semantic versioning_

## UNRELEASED

- Change conversion between binary operators to consider converting both types [#320](https://github.com/egraphs-good/egglog-python/pull/320)
- Add ability to parse egglog expressions into Python values [#319](https://github.com/egraphs-good/egglog-python/pull/319)
- Deprecates `.eval()` method on primitives in favor of `.value` which can be used with pattern matching.
- Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315)
Expand Down
15 changes: 5 additions & 10 deletions docs/reference/python-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,13 @@ instead of the normal mechanism which relies on `__getattr__`, you can call `egg
with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
method is defined instead of just attempting to call it.

### Reflected methods
### Binary Method Conversions

Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.
For [rich comparison methods](https://docs.python.org/3/reference/datamodel.html#object.__lt__) (like `__lt__`, `__le__`, `__eq__`, etc.) and [binary numeric methods](https://docs.python.org/3/reference/datamodel.html#object.__add__) (like `__add__`, `__sub__`, etc.), some more advanced conversion logic is needed to ensure they are converted properly. We add the `__r<name>__` methods for all expressions so that we can handle either position they are placed in.

Instead, whenever a reflected method is called, we will try to find the corresponding non-reflected method and call that instead.

Also, if a normal method fails because the arguments cannot be converted to the right types, the reflected version of the second arg will be tried.
If we have two values `lhs` and `rhs`, we will try to find the minimum cost conversion for both of them, and then call the method on the converted values.
If both are expression instances, we will convert at most one of them. However, if one is an expression and the other
is a different Python value (like an `int`), we will consider all possible conversions of both arguments to find the minimum.

```{code-cell} python
class Int(Expr):
Expand Down Expand Up @@ -423,11 +423,6 @@ converter(Int, Float, Float.from_int)
assert str(-1.0 + Int.var("x")) == "Float(-1.0) + Float.from_int(Int.var(\"x\"))"
```

For methods which allow returning `NotImplemented`, i.e. the comparison + binary math methods, we will also try upcasting both
types to the type which is lowest cost to convert both to.

For example, if you have `Float` and `Int` wrapper types and you have write the expr `-1.0 + Int.var("x")` you might want the result to be `Float(-1.0) + Float.from_int(Int.var("x"))`:

### Mutating arguments

In order to support Python functions and methods which mutate their arguments, you can pass in the `mutate_first_arg` keyword argument to the `@function` decorator and the `mutates_self` argument to the `@method` decorator. This will cause the first argument to be mutated in place, instead of being copied.
Expand Down
82 changes: 47 additions & 35 deletions python/egglog/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,41 +143,53 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
return tp


# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
# """
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
# """
# decls = _retrieve_conversion_decls().copy()
# if isinstance(a, RuntimeExpr):
# decls |= a
# if isinstance(b, RuntimeExpr):
# decls |= b

# a_tp = _get_tp(a)
# b_tp = _get_tp(b)
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
# if not (
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
# ):
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
# a_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
# }
# b_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
# }
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
# a_converts_to[a_tp] = 0
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
# b_converts_to[b_tp] = 0
# common = set(a_converts_to) & set(b_converts_to)
# if not common:
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])


def identity(x: object) -> object:
def min_binary_conversion(
method_name: str, lhs: type | JustTypeRef, rhs: type | JustTypeRef
) -> tuple[Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None:
"""
Given a binary method and two starting types for the LHS and RHS, return a pair of callable which will convert
the LHS and RHS to appropriate types which support this method. If no such conversion is possible, return None.

It should return the types which minimize the total conversion cost. If one of the types is a Python type, then
both of them can be converted. However, if both are egglog types, then only one of them can be converted.
"""
decls = retrieve_conversion_decls()
# tuple of (cost, convert lhs, convert rhs)
best_method: tuple[int, Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None = None

possible_lhs = _all_conversions_from(lhs) if isinstance(lhs, type) else [(0, lhs, identity)]
possible_rhs = _all_conversions_from(rhs) if isinstance(rhs, type) else [(0, rhs, identity)]
for lhs_cost, lhs_converted_type, lhs_convert in possible_lhs:
# Start by checking if we have a LHS that matches exactly and a RHS which can be converted
if (desired_other_type := decls.check_binary_method_with_self_type(method_name, lhs_converted_type)) and (
converter := CONVERSIONS.get((rhs, desired_other_type))
):
cost = lhs_cost + converter[0]
if best_method is None or best_method[0] > cost:
best_method = (cost, lhs_convert, converter[1])

for rhs_cost, rhs_converted_type, rhs_convert in possible_rhs:
# Next see if it's possible to convert the LHS and keep the RHS as is
for desired_self_type in decls.check_binary_method_with_other_type(method_name, rhs_converted_type):
if converter := CONVERSIONS.get((lhs, desired_self_type)):
cost = rhs_cost + converter[0]
if best_method is None or best_method[0] > cost:
best_method = (cost, converter[1], rhs_convert)
if best_method is None:
return None
return best_method[1], best_method[2]


def _all_conversions_from(tp: JustTypeRef | type) -> list[tuple[int, JustTypeRef, Callable[[Any], RuntimeExpr]]]:
"""
Get all conversions from a type to other types.

Returns a list of tuples of (cost, target type, conversion function).
"""
return [(cost, target, fn) for (source, target), (cost, fn) in CONVERSIONS.items() if source == tp]


def identity(x: Any) -> Any:
return x


Expand Down
5 changes: 4 additions & 1 deletion python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTy
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
"""
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := self._classes[self_type.name].methods.get(method_name):
class_decl = self._classes.get(self_type.name)
if class_decl is None:
return None
if callable_decl := class_decl.methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
return other_arg_type.to_just(vars)
Expand Down
1 change: 0 additions & 1 deletion python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
"__weakref__",
"__orig_bases__",
"__annotations__",
"__hash__",
"__qualname__",
"__firstlineno__",
"__static_attributes__",
Expand Down
38 changes: 36 additions & 2 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ def __le__(self, other: IntLike) -> Boolean: ...
def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
...

# add a hash so that this test can pass
# https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
@method(preserve=True)
def __hash__(self) -> int:
egraph = _get_current_egraph()
egraph.register(self)
egraph.run(array_api_schedule)
simplified = egraph.extract(self)
return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)

def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...

# TODO: Fix this?
# Make != always return a Bool, so that numpy.unique works on a tuple of ints
# In _unique1d
Expand Down Expand Up @@ -280,6 +292,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)

yield rewrite(o.__round__(OptionalInt.none)).to(o)

# Never cannot be equal to anything real
yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))

Expand Down Expand Up @@ -354,8 +368,14 @@ def __add__(self, other: FloatLike) -> Float: ...
def __sub__(self, other: FloatLike) -> Float: ...

def __pow__(self, other: FloatLike) -> Float: ...
def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...

def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
def __ne__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
def __lt__(self, other: FloatLike) -> Boolean: ...
def __le__(self, other: FloatLike) -> Boolean: ...
def __gt__(self, other: FloatLike) -> Boolean: ...
def __ge__(self, other: FloatLike) -> Boolean: ...


converter(float, Float, lambda x: Float(x))
Expand All @@ -366,9 +386,10 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]


@array_api_ruleset.register
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
return [
rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
# Convert from float to rationl, if its a whole number i.e. can be converted to int
Expand All @@ -383,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
# ==
# comparisons
rewrite(Float(f) == Float(f)).to(TRUE),
rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
rewrite(Float(f) != Float(f)).to(FALSE),
rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
# round
rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
]


Expand Down Expand Up @@ -671,6 +703,8 @@ class OptionalInt(Expr, ruleset=array_api_ruleset):
def some(cls, value: Int) -> OptionalInt: ...


OptionalIntLike: TypeAlias = OptionalInt | IntLike | None

converter(type(None), OptionalInt, lambda _: OptionalInt.none)
converter(Int, OptionalInt, OptionalInt.some)

Expand Down
16 changes: 11 additions & 5 deletions python/egglog/exp/array_api_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@
X = TypeVar("X", bound=Callable)


def jit(fn: X) -> X:
def jit(
fn: X,
*,
handle_expr: Callable[[NDArray], None] | None = None,
handle_optimized_expr: Callable[[NDArray], None] | None = None,
) -> X:
"""
Jit compiles a function
"""
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
if handle_expr:
handle_expr(res)
if handle_optimized_expr:
handle_optimized_expr(res_optimized)
fn_program = EvalProgram(program, {"np": np})
fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
fn.initial_expr = res # type: ignore[attr-defined]
fn.expr = res_optimized # type: ignore[attr-defined]
return fn
return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))


def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/exp/array_api_program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,6 @@ def bin_op(res: NDArray, op: str) -> Command:
yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())

# asarray
yield rewrite(ndarray_program(asarray(x, odtype))).to(
yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
)
3 changes: 3 additions & 0 deletions python/egglog/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def _call_inner( # noqa: C901, PLR0911, PLR0912
return f"{tp_ref}.{method_name}", args
case MethodRef(_class_name, method_name):
slf, *args = args
non_str_slf = slf
slf = self(slf, parens=True)
match method_name:
case _ if method_name in UNARY_METHODS:
Expand All @@ -410,6 +411,8 @@ def _call_inner( # noqa: C901, PLR0911, PLR0912
return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
case "__setitem__":
return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
case "__round__":
return "round", [non_str_slf, *args]
case _:
return f"{slf}.{method_name}", args
case ConstantRef(name):
Expand Down
48 changes: 11 additions & 37 deletions python/egglog/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@
"__and__",
"__xor__",
"__or__",
"__lt__",
"__le__",
"__gt__",
"__ge__",
}


Expand All @@ -85,10 +89,7 @@
"__pos__",
"__neg__",
"__invert__",
"__lt__",
"__le__",
"__gt__",
"__ge__",
"__round__",
}

# Set this globally so we can get access to PyObject when we have a type annotation of just object.
Expand Down Expand Up @@ -567,6 +568,8 @@ def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])

def __hash__(self) -> int:
if (method := _get_expr_method(self, "__hash__")) is not None:
return cast("int", cast("Any", method()))
return hash(self.__egg_typed_expr__)

# Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
Expand Down Expand Up @@ -647,42 +650,13 @@ def _numeric_binary_method(self: object, other: object, name: str = name, r_meth
)
)
):
from .conversion import CONVERSIONS, resolve_type, retrieve_conversion_decls # noqa: PLC0415

# tuple of (cost, convert_self)
best_method: (
tuple[
int,
Callable[[Any], RuntimeExpr],
]
| None
) = None
# Start by checking if we have a LHS that matches exactly and a RHS which can be converted
if (
isinstance(self, RuntimeExpr)
and (
desired_other_type := self.__egg_decls__.check_binary_method_with_self_type(
name, self.__egg_typed_expr__.tp
)
)
and (converter := CONVERSIONS.get((resolve_type(other), desired_other_type)))
):
best_method = (converter[0], lambda x: x)

# Next see if it's possible to convert the LHS and keep the RHS as is
if isinstance(other, RuntimeExpr):
decls = retrieve_conversion_decls()
other_type = other.__egg_typed_expr__.tp
resolved_self_type = resolve_type(self)
for desired_self_type in decls.check_binary_method_with_other_type(name, other_type):
if converter := CONVERSIONS.get((resolved_self_type, desired_self_type)):
cost, convert_self = converter
if best_method is None or best_method[0] > cost:
best_method = (cost, convert_self)
from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415

best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))

if not best_method:
raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
self = best_method[1](self)
self, other = best_method[0](self), best_method[1](other)

method_ref = MethodRef(self.__egg_class_name__, name)
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __fn(X, y):
_8[2, :,] = _14
_15 = _7 @ _8
_16 = X - _15
_17 = np.sqrt(np.asarray(np.array((float(1) / 147)), np.dtype(np.float64)))
_17 = np.sqrt(np.asarray(np.array(float(1 / 147)), np.dtype(np.float64)))
_18 = X[_0] - _8[0, :,]
_19 = X[_2] - _8[1, :,]
_20 = X[_4] - _8[2, :,]
Expand All @@ -49,7 +49,7 @@ def __fn(X, y):
_37 = _33[2][:_36, :,] / _29
_38 = _37.T / _33[1][:_36]
_39 = np.array(150) * _7
_40 = _39 * np.array((float(1) / 2))
_40 = _39 * np.array(float(1 / 2))
_41 = np.sqrt(_40)
_42 = _8 - _15
_43 = _41 * _42.T
Expand Down
Loading