Skip to content

Commit ea91781

Browse files
Add support for adding methods actually defined on classes
Adds support for methods like `__array_function__` which actually need to be added on the class as actual methods, not through overloading `__getattr__`. Custom methods can be registered by third party libraries. This PR also redoes the logic for upcasting when using binary operations. Instead of upcasting both values, it will only ever upcast one, choosing whichever one would be cheaper to upcast. This leads to more predictable behavior.
1 parent 00d1b65 commit ea91781

File tree

10 files changed

+292
-246
lines changed

10 files changed

+292
-246
lines changed

docs/reference/python-integration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since
303303
- `__iter_`
304304
- `__index__`
305305

306+
If you want to register additional methods as always preserved and defined on the `Expr` class itself, if needed
307+
instead of the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)`,
308+
with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
309+
method is defined instead of just attempting to call it.
310+
306311
### Reflected methods
307312

308313
Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.

python/egglog/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from . import config, ipython_magic # noqa: F401
66
from .bindings import EggSmolError # noqa: F401
77
from .builtins import * # noqa: UP029
8-
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
8+
from .conversion import *
99
from .egraph import *
10+
from .runtime import define_expr_method as define_expr_method # noqa: PLC0414
1011

1112
del ipython_magic

python/egglog/conversion.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from collections.abc import Callable
45
from contextlib import contextmanager
56
from contextvars import ContextVar
67
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, TypeVar, cast
8+
from typing import TYPE_CHECKING, Any, TypeVar, cast
89

910
from .declarations import *
1011
from .pretty import *
@@ -13,22 +14,22 @@
1314
from .type_constraint_solver import TypeConstraintError
1415

1516
if TYPE_CHECKING:
16-
from collections.abc import Callable, Generator
17+
from collections.abc import Generator
1718

1819
from .egraph import BaseExpr
1920
from .type_constraint_solver import TypeConstraintSolver
2021

21-
__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
22+
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
2223
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
23-
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
24+
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
2425
# Global declerations to store all convertable types so we can query if they have certain methods or not
2526
_CONVERSION_DECLS = Declarations.create()
2627
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
2728
# until we need them
2829
_TO_PROCESS_DECLS: list[DeclerationsLike] = []
2930

3031

31-
def _retrieve_conversion_decls() -> Declarations:
32+
def retrieve_conversion_decls() -> Declarations:
3233
_CONVERSION_DECLS.update(*_TO_PROCESS_DECLS)
3334
_TO_PROCESS_DECLS.clear()
3435
return _CONVERSION_DECLS
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
4950
to_type_name = process_tp(to_type)
5051
if not isinstance(to_type_name, JustTypeRef):
5152
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
52-
_register_converter(process_tp(from_type), to_type_name, fn, cost)
53+
_register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)
5354

5455

55-
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
56+
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
5657
"""
5758
Registers a converter from some type to an egglog type, if not already registered.
5859
@@ -97,15 +98,15 @@ class _ComposedConverter:
9798
We use the dataclass instead of the lambda to make it easier to debug.
9899
"""
99100

100-
a_b: Callable
101-
b_c: Callable
101+
a_b: Callable[[Any], RuntimeExpr]
102+
b_c: Callable[[Any], RuntimeExpr]
102103
b_args: tuple[JustTypeRef, ...]
103104

104-
def __call__(self, x: object) -> object:
105+
def __call__(self, x: Any) -> RuntimeExpr:
105106
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
106107
# when converting from A -> B
107108
if self.b_args:
108-
with with_type_args(self.b_args, _retrieve_conversion_decls):
109+
with with_type_args(self.b_args, retrieve_conversion_decls):
109110
first_res = self.a_b(x)
110111
else:
111112
first_res = self.a_b(x)
@@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
142143
return tp
143144

144145

145-
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
146-
"""
147-
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
148-
"""
149-
decls = _retrieve_conversion_decls()
150-
a_tp = _get_tp(a)
151-
b_tp = _get_tp(b)
152-
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153-
if not (
154-
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155-
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156-
):
157-
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
158-
a_converts_to = {
159-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
160-
}
161-
b_converts_to = {
162-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
163-
}
164-
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
165-
a_converts_to[a_tp] = 0
166-
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
167-
b_converts_to[b_tp] = 0
168-
common = set(a_converts_to) & set(b_converts_to)
169-
if not common:
170-
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
171-
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
146+
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
147+
# """
148+
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
149+
# """
150+
# decls = _retrieve_conversion_decls().copy()
151+
# if isinstance(a, RuntimeExpr):
152+
# decls |= a
153+
# if isinstance(b, RuntimeExpr):
154+
# decls |= b
155+
156+
# a_tp = _get_tp(a)
157+
# b_tp = _get_tp(b)
158+
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
159+
# if not (
160+
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
161+
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
162+
# ):
163+
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
164+
# a_converts_to = {
165+
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
166+
# }
167+
# b_converts_to = {
168+
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
169+
# }
170+
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
171+
# a_converts_to[a_tp] = 0
172+
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
173+
# b_converts_to[b_tp] = 0
174+
# common = set(a_converts_to) & set(b_converts_to)
175+
# if not common:
176+
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
177+
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
172178

173179

174180
def identity(x: object) -> object:
@@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
197203
def resolve_literal(
198204
tp: TypeOrVarRef,
199205
arg: object,
200-
decls: Callable[[], Declarations] = _retrieve_conversion_decls,
206+
decls: Callable[[], Declarations] = retrieve_conversion_decls,
201207
tcs: TypeConstraintSolver | None = None,
202208
cls_name: str | None = None,
203209
) -> RuntimeExpr:
@@ -208,12 +214,12 @@ def resolve_literal(
208214
209215
If it cannot be resolved, we assume that the value passed in will resolve it.
210216
"""
211-
arg_type = _get_tp(arg)
217+
arg_type = resolve_type(arg)
212218

213219
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
214220
try:
215221
tp_just = tp.to_just()
216-
except NotImplementedError:
222+
except TypeVarError:
217223
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
218224
# args first based on the existing type constraint solver
219225
if tcs:
@@ -258,7 +264,7 @@ def _debug_print_converers():
258264
source_to_targets[source].append(target)
259265

260266

261-
def _get_tp(x: object) -> JustTypeRef | type:
267+
def resolve_type(x: object) -> JustTypeRef | type:
262268
if isinstance(x, RuntimeExpr):
263269
return x.__egg_typed_expr__.tp
264270
tp = type(x)

python/egglog/declarations.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"SpecialFunctions",
7474
"TypeOrVarRef",
7575
"TypeRefWithVars",
76+
"TypeVarError",
7677
"TypedExprDecl",
7778
"UnboundVarDecl",
7879
"UnionDecl",
@@ -95,7 +96,7 @@ def __egg_decls__(self) -> Declarations:
9596
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
9697
# instead raise explicitly
9798
except AttributeError as err:
98-
msg = f"Cannot resolve declarations for {self}"
99+
msg = f"Cannot resolve declarations {err}"
99100
raise RuntimeError(msg) from err
100101

101102

@@ -225,14 +226,43 @@ def set_function_decl(
225226
case _:
226227
assert_never(ref)
227228

228-
def has_method(self, class_name: str, method_name: str) -> bool | None:
229+
def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
229230
"""
230-
Returns whether the given class has the given method, or None if we cant find the class.
231+
Checks if the class has a binary method compatible with the given types.
231232
"""
232-
if class_name in self._classes:
233-
return method_name in self._classes[class_name].methods
233+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
234+
if callable_decl := self._classes[self_type.name].methods.get(method_name):
235+
match callable_decl.signature:
236+
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
237+
vars, self_type
238+
) and other_arg_type.matches_just(vars, other_type):
239+
return True
240+
return False
241+
242+
def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
243+
"""
244+
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
245+
"""
246+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
247+
if callable_decl := self._classes[self_type.name].methods.get(method_name):
248+
match callable_decl.signature:
249+
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
250+
return other_arg_type.to_just(vars)
234251
return None
235252

253+
def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
254+
"""
255+
Returns the types which are compatible with the given binary method name and other type.
256+
"""
257+
for class_decl in self._classes.values():
258+
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
259+
if callable_decl := class_decl.methods.get(method_name):
260+
match callable_decl.signature:
261+
case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
262+
vars, other_type
263+
):
264+
yield self_arg_type.to_just(vars)
265+
236266
def get_class_decl(self, name: str) -> ClassDecl:
237267
return self._classes[name]
238268

@@ -300,6 +330,10 @@ def __str__(self) -> str:
300330
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}
301331

302332

333+
class TypeVarError(RuntimeError):
334+
"""Error when trying to resolve a type variable that doesn't exist."""
335+
336+
303337
@dataclass(frozen=True)
304338
class ClassTypeVarRef:
305339
"""
@@ -309,9 +343,10 @@ class ClassTypeVarRef:
309343
name: str
310344
module: str
311345

312-
def to_just(self) -> JustTypeRef:
313-
msg = f"{self}: egglog does not support generic classes yet."
314-
raise NotImplementedError(msg)
346+
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
347+
if vars is None or self not in vars:
348+
raise TypeVarError(f"Cannot convert {self} to just type")
349+
return vars[self]
315350

316351
def __str__(self) -> str:
317352
return str(self.to_type_var())
@@ -325,20 +360,39 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
325360
def to_type_var(self) -> TypeVar:
326361
return _RESOLVED_TYPEVARS[self]
327362

363+
def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
364+
"""
365+
Checks if this type variable matches the given JustTypeRef, including type variables.
366+
"""
367+
if self in vars:
368+
return vars[self] == other
369+
vars[self] = other
370+
return True
371+
328372

329373
@dataclass(frozen=True)
330374
class TypeRefWithVars:
331375
name: str
332376
args: tuple[TypeOrVarRef, ...] = ()
333377

334-
def to_just(self) -> JustTypeRef:
335-
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
378+
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
379+
return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))
336380

337381
def __str__(self) -> str:
338382
if self.args:
339383
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
340384
return self.name
341385

386+
def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
387+
"""
388+
Checks if this type reference matches the given JustTypeRef, including type variables.
389+
"""
390+
return (
391+
self.name == other.name
392+
and len(self.args) == len(other.args)
393+
and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
394+
)
395+
342396

343397
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
344398

python/egglog/egraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@
115115
"__firstlineno__",
116116
"__static_attributes__",
117117
# Ignore all reflected binary method
118-
*REFLECTED_BINARY_METHODS.keys(),
118+
*(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
119119
}
120120

121121

0 commit comments

Comments
 (0)