Skip to content

Commit cbf9160

Browse files
Merge pull request #315 from egraphs-good/redo-conversion-logic
Support methods like `__array_function__` on expressions
2 parents bf04275 + 34fef64 commit cbf9160

File tree

11 files changed

+293
-246
lines changed

11 files changed

+293
-246
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315)
78
- Automatically Create Changelog Entry for PRs [#313](https://github.com/egraphs-good/egglog-python/pull/313)
89
- Upgrade egglog which includes new backend.
910
- Fixes implementation of the Python Object sort to work with objects with dupliating hashes but the same value.

docs/reference/python-integration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since
303303
- `__iter_`
304304
- `__index__`
305305

306+
If you want to register additional methods as always preserved and defined on the `Expr` class itself, if needed
307+
instead of the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)`,
308+
with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
309+
method is defined instead of just attempting to call it.
310+
306311
### Reflected methods
307312

308313
Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.

python/egglog/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from . import config, ipython_magic # noqa: F401
66
from .bindings import EggSmolError # noqa: F401
77
from .builtins import * # noqa: UP029
8-
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
8+
from .conversion import *
99
from .egraph import *
10+
from .runtime import define_expr_method as define_expr_method # noqa: PLC0414
1011

1112
del ipython_magic

python/egglog/conversion.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from collections.abc import Callable
45
from contextlib import contextmanager
56
from contextvars import ContextVar
67
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, TypeVar, cast
8+
from typing import TYPE_CHECKING, Any, TypeVar, cast
89

910
from .declarations import *
1011
from .pretty import *
@@ -13,22 +14,22 @@
1314
from .type_constraint_solver import TypeConstraintError
1415

1516
if TYPE_CHECKING:
16-
from collections.abc import Callable, Generator
17+
from collections.abc import Generator
1718

1819
from .egraph import BaseExpr
1920
from .type_constraint_solver import TypeConstraintSolver
2021

21-
__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
22+
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
2223
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
23-
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
24+
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
2425
# Global declerations to store all convertable types so we can query if they have certain methods or not
2526
_CONVERSION_DECLS = Declarations.create()
2627
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
2728
# until we need them
2829
_TO_PROCESS_DECLS: list[DeclerationsLike] = []
2930

3031

31-
def _retrieve_conversion_decls() -> Declarations:
32+
def retrieve_conversion_decls() -> Declarations:
3233
_CONVERSION_DECLS.update(*_TO_PROCESS_DECLS)
3334
_TO_PROCESS_DECLS.clear()
3435
return _CONVERSION_DECLS
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
4950
to_type_name = process_tp(to_type)
5051
if not isinstance(to_type_name, JustTypeRef):
5152
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
52-
_register_converter(process_tp(from_type), to_type_name, fn, cost)
53+
_register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)
5354

5455

55-
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
56+
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
5657
"""
5758
Registers a converter from some type to an egglog type, if not already registered.
5859
@@ -97,15 +98,15 @@ class _ComposedConverter:
9798
We use the dataclass instead of the lambda to make it easier to debug.
9899
"""
99100

100-
a_b: Callable
101-
b_c: Callable
101+
a_b: Callable[[Any], RuntimeExpr]
102+
b_c: Callable[[Any], RuntimeExpr]
102103
b_args: tuple[JustTypeRef, ...]
103104

104-
def __call__(self, x: object) -> object:
105+
def __call__(self, x: Any) -> RuntimeExpr:
105106
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
106107
# when converting from A -> B
107108
if self.b_args:
108-
with with_type_args(self.b_args, _retrieve_conversion_decls):
109+
with with_type_args(self.b_args, retrieve_conversion_decls):
109110
first_res = self.a_b(x)
110111
else:
111112
first_res = self.a_b(x)
@@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
142143
return tp
143144

144145

145-
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
146-
"""
147-
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
148-
"""
149-
decls = _retrieve_conversion_decls()
150-
a_tp = _get_tp(a)
151-
b_tp = _get_tp(b)
152-
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153-
if not (
154-
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155-
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156-
):
157-
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
158-
a_converts_to = {
159-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
160-
}
161-
b_converts_to = {
162-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
163-
}
164-
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
165-
a_converts_to[a_tp] = 0
166-
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
167-
b_converts_to[b_tp] = 0
168-
common = set(a_converts_to) & set(b_converts_to)
169-
if not common:
170-
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
171-
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
146+
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
147+
# """
148+
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
149+
# """
150+
# decls = _retrieve_conversion_decls().copy()
151+
# if isinstance(a, RuntimeExpr):
152+
# decls |= a
153+
# if isinstance(b, RuntimeExpr):
154+
# decls |= b
155+
156+
# a_tp = _get_tp(a)
157+
# b_tp = _get_tp(b)
158+
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
159+
# if not (
160+
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
161+
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
162+
# ):
163+
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
164+
# a_converts_to = {
165+
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
166+
# }
167+
# b_converts_to = {
168+
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
169+
# }
170+
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
171+
# a_converts_to[a_tp] = 0
172+
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
173+
# b_converts_to[b_tp] = 0
174+
# common = set(a_converts_to) & set(b_converts_to)
175+
# if not common:
176+
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
177+
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
172178

173179

174180
def identity(x: object) -> object:
@@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
197203
def resolve_literal(
198204
tp: TypeOrVarRef,
199205
arg: object,
200-
decls: Callable[[], Declarations] = _retrieve_conversion_decls,
206+
decls: Callable[[], Declarations] = retrieve_conversion_decls,
201207
tcs: TypeConstraintSolver | None = None,
202208
cls_name: str | None = None,
203209
) -> RuntimeExpr:
@@ -208,12 +214,12 @@ def resolve_literal(
208214
209215
If it cannot be resolved, we assume that the value passed in will resolve it.
210216
"""
211-
arg_type = _get_tp(arg)
217+
arg_type = resolve_type(arg)
212218

213219
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
214220
try:
215221
tp_just = tp.to_just()
216-
except NotImplementedError:
222+
except TypeVarError:
217223
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
218224
# args first based on the existing type constraint solver
219225
if tcs:
@@ -258,7 +264,7 @@ def _debug_print_converers():
258264
source_to_targets[source].append(target)
259265

260266

261-
def _get_tp(x: object) -> JustTypeRef | type:
267+
def resolve_type(x: object) -> JustTypeRef | type:
262268
if isinstance(x, RuntimeExpr):
263269
return x.__egg_typed_expr__.tp
264270
tp = type(x)

python/egglog/declarations.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"SpecialFunctions",
7474
"TypeOrVarRef",
7575
"TypeRefWithVars",
76+
"TypeVarError",
7677
"TypedExprDecl",
7778
"UnboundVarDecl",
7879
"UnionDecl",
@@ -95,7 +96,7 @@ def __egg_decls__(self) -> Declarations:
9596
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
9697
# instead raise explicitly
9798
except AttributeError as err:
98-
msg = f"Cannot resolve declarations for {self}"
99+
msg = f"Cannot resolve declarations for {self}: {err}"
99100
raise RuntimeError(msg) from err
100101

101102

@@ -225,14 +226,43 @@ def set_function_decl(
225226
case _:
226227
assert_never(ref)
227228

228-
def has_method(self, class_name: str, method_name: str) -> bool | None:
229+
def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
229230
"""
230-
Returns whether the given class has the given method, or None if we cant find the class.
231+
Checks if the class has a binary method compatible with the given types.
231232
"""
232-
if class_name in self._classes:
233-
return method_name in self._classes[class_name].methods
233+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
234+
if callable_decl := self._classes[self_type.name].methods.get(method_name):
235+
match callable_decl.signature:
236+
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
237+
vars, self_type
238+
) and other_arg_type.matches_just(vars, other_type):
239+
return True
240+
return False
241+
242+
def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
243+
"""
244+
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
245+
"""
246+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
247+
if callable_decl := self._classes[self_type.name].methods.get(method_name):
248+
match callable_decl.signature:
249+
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
250+
return other_arg_type.to_just(vars)
234251
return None
235252

253+
def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
254+
"""
255+
Returns the types which are compatible with the given binary method name and other type.
256+
"""
257+
for class_decl in self._classes.values():
258+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
259+
if callable_decl := class_decl.methods.get(method_name):
260+
match callable_decl.signature:
261+
case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
262+
vars, other_type
263+
):
264+
yield self_arg_type.to_just(vars)
265+
236266
def get_class_decl(self, name: str) -> ClassDecl:
237267
return self._classes[name]
238268

@@ -300,6 +330,10 @@ def __str__(self) -> str:
300330
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}
301331

302332

333+
class TypeVarError(RuntimeError):
334+
"""Error when trying to resolve a type variable that doesn't exist."""
335+
336+
303337
@dataclass(frozen=True)
304338
class ClassTypeVarRef:
305339
"""
@@ -309,9 +343,10 @@ class ClassTypeVarRef:
309343
name: str
310344
module: str
311345

312-
def to_just(self) -> JustTypeRef:
313-
msg = f"{self}: egglog does not support generic classes yet."
314-
raise NotImplementedError(msg)
346+
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
347+
if vars is None or self not in vars:
348+
raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings")
349+
return vars[self]
315350

316351
def __str__(self) -> str:
317352
return str(self.to_type_var())
@@ -325,20 +360,39 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
325360
def to_type_var(self) -> TypeVar:
326361
return _RESOLVED_TYPEVARS[self]
327362

363+
def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
364+
"""
365+
Checks if this type variable matches the given JustTypeRef, including type variables.
366+
"""
367+
if self in vars:
368+
return vars[self] == other
369+
vars[self] = other
370+
return True
371+
328372

329373
@dataclass(frozen=True)
330374
class TypeRefWithVars:
331375
name: str
332376
args: tuple[TypeOrVarRef, ...] = ()
333377

334-
def to_just(self) -> JustTypeRef:
335-
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
378+
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
379+
return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))
336380

337381
def __str__(self) -> str:
338382
if self.args:
339383
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
340384
return self.name
341385

386+
def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
387+
"""
388+
Checks if this type reference matches the given JustTypeRef, including type variables.
389+
"""
390+
return (
391+
self.name == other.name
392+
and len(self.args) == len(other.args)
393+
and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
394+
)
395+
342396

343397
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
344398

python/egglog/egraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
"__firstlineno__",
116116
"__static_attributes__",
117117
# Ignore all reflected binary method
118-
*REFLECTED_BINARY_METHODS.keys(),
118+
*(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
119119
}
120120

121121

0 commit comments

Comments
 (0)