Skip to content

Commit 2ccaad7

Browse files
Add support for adding methods actually defined on classes
Adds support for methods like `__array_function__` which actually need to be added on the class as actual methods, not through overloading `__getattr__`. Custom methods can be registered by third party libraries. This PR also redoes the logic for upcasting when using binary operations. Instead of upcasting both values, it will only ever upcast one, choosing whichever one would be cheaper to upcast. This leads to more predictable behavior.
1 parent 4a43cd8 commit 2ccaad7

File tree

14 files changed

+468
-495
lines changed

14 files changed

+468
-495
lines changed

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ _This project uses semantic versioning_
99
Also changes the representation to be an index into a list instead of the ID, making egglog programs more deterministic.
1010
- Prefix constant declerations and unbound variables to not shadow let variables
1111
- BREAKING: Remove `simplify` since it was removed upstream. You can manually replace it with an insert, run, then extract.
12+
- Change how anonymous functions are converted to remove metaprogramming and lift only the unbound variables as args
13+
- Add support for getting the "value" of a function type with `.eval()`, i.e. `assert UnstableFn(f).eval() == f`.
1214

1315
## 10.0.2 (2025-06-22)
1416

docs/reference/python-integration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since
303303
- `__iter_`
304304
- `__index__`
305305

306+
If you want to register additional methods as always preserved and defined on the `Expr` class itself, if needed
307+
instead of the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)`,
308+
with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
309+
method is defined instead of just attempting to call it.
310+
306311
### Reflected methods
307312

308313
Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.

python/egglog/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from . import config, ipython_magic # noqa: F401
66
from .bindings import EggSmolError # noqa: F401
77
from .builtins import * # noqa: UP029
8-
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
8+
from .conversion import *
99
from .egraph import *
10+
from .runtime import define_expr_method as define_expr_method # noqa: PLC0414
1011

1112
del ipython_magic

python/egglog/builtins.py

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
from collections.abc import Callable
99
from fractions import Fraction
1010
from functools import partial, reduce
11+
from inspect import signature
1112
from types import FunctionType, MethodType
1213
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload
1314

1415
from typing_extensions import TypeVarTuple, Unpack
1516

16-
from .conversion import convert, converter, get_type_args
17+
from egglog.declarations import TypedExprDecl
18+
19+
from .conversion import convert, converter, get_type_args, resolve_literal
1720
from .declarations import *
18-
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
19-
from .functionalize import functionalize
20-
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
21+
from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method
22+
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate
2123
from .thunk import Thunk
2224

2325
if TYPE_CHECKING:
@@ -1022,46 +1024,89 @@ def __init__(self, f, *partial) -> None: ...
10221024
@method(egg_fn="unstable-app")
10231025
def __call__(self, *args: Unpack[TS]) -> T: ...
10241026

1027+
@method(preserve=True)
1028+
def eval(self) -> Callable[[Unpack[TS]], T]:
1029+
"""
1030+
If this is a constructor, returns either the callable directly or a `functools.partial` function if args are provided.
1031+
"""
1032+
assert isinstance(self, RuntimeExpr)
1033+
match self.__egg_typed_expr__.expr:
1034+
case PartialCallDecl(CallDecl() as call):
1035+
fn, args = _deconstruct_call_decl(self.__egg_decls_thunk__, call)
1036+
if not args:
1037+
return fn
1038+
return partial(fn, *args)
1039+
msg = "UnstableFn can only be evaluated if it is a function or a partial application of a function."
1040+
raise BuiltinEvalError(msg)
1041+
1042+
1043+
def _deconstruct_call_decl(
1044+
decls_thunk: Callable[[], Declarations], call: CallDecl
1045+
) -> tuple[Callable, tuple[object, ...]]:
1046+
"""
1047+
Deconstructs a CallDecl into a runtime callable and its arguments.
1048+
"""
1049+
args = call.args
1050+
arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
1051+
egg_bound = (
1052+
JustTypeRef(call.callable.class_name, call.bound_tp_params or ())
1053+
if isinstance(call.callable, (ClassMethodRef, InitRef, ClassVariableRef))
1054+
else None
1055+
)
1056+
if isinstance(call.callable, InitRef):
1057+
return RuntimeClass(
1058+
decls_thunk,
1059+
TypeRefWithVars(
1060+
call.callable.class_name,
1061+
),
1062+
), arg_exprs
1063+
return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs
1064+
10251065

10261066
# Method Type is for builtins like __getitem__
10271067
converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__))
10281068
converter(RuntimeFunction, UnstableFn, UnstableFn)
10291069
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
10301070

10311071

1032-
def _convert_function(a: FunctionType) -> UnstableFn:
1072+
def _convert_function(fn: FunctionType) -> UnstableFn:
10331073
"""
1034-
Converts a function type to an unstable function
1074+
Converts a function type to an unstable function. This function will be an anon function in egglog.
10351075
1036-
Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
1037-
which are runtime expressions with `var`s in them and add them as args to the function
1076+
Would just be UnstableFn(function(a)) but we have to account for unbound vars within the body.
1077+
1078+
This means that we have to turn all of those unbound vars into args to the function, and then
1079+
partially apply them, alongside creating a default rewrite for the function.
10381080
"""
1039-
# Update annotations of a to be the type we are trying to convert to
1040-
return_tp, *arg_tps = get_type_args()
1041-
a.__annotations__ = {
1042-
"return": return_tp,
1043-
# The first varnames should always be the arg names
1044-
**dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
1045-
}
1046-
# Modify name to make it unique
1047-
# a.__name__ = f"{a.__name__} {hash(a.__code__)}"
1048-
transformed_fn = functionalize(a, value_to_annotation)
1049-
assert isinstance(transformed_fn, partial)
1050-
return UnstableFn(
1051-
function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
1052-
*transformed_fn.args,
1081+
decls = Declarations()
1082+
return_type, *arg_types = [resolve_type_annotation_mutate(decls, tp) for tp in get_type_args()]
1083+
arg_names = [p.name for p in signature(fn).parameters.values()]
1084+
arg_decls = [
1085+
TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True)
1086+
]
1087+
res = resolve_literal(
1088+
return_type, fn(*(RuntimeExpr.__from_values__(decls, a) for a in arg_decls)), Thunk.value(decls)
10531089
)
1090+
res_expr = res.__egg_typed_expr__
1091+
decls |= res
1092+
# these are all the args that appear in the body that are not bound by the args of the function
1093+
unbound_vars = list(collect_unbound_vars(res_expr) - set(arg_decls))
1094+
# prefix the args with them
1095+
fn_ref = UnnamedFunctionRef(tuple(unbound_vars + arg_decls), res_expr)
1096+
rewrite_decl = DefaultRewriteDecl(fn_ref, res_expr.expr, subsume=True)
1097+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, get_current_ruleset())
1098+
ruleset_decls |= res
10541099

1055-
1056-
def value_to_annotation(a: object) -> type | None:
1057-
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
1058-
if not isinstance(a, RuntimeExpr):
1059-
return None
1060-
return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
1100+
fn = RuntimeFunction(Thunk.value(decls), Thunk.value(fn_ref))
1101+
return UnstableFn(fn, *(RuntimeExpr.__from_values__(decls, v) for v in unbound_vars))
10611102

10621103

10631104
converter(FunctionType, UnstableFn, _convert_function)
10641105

1106+
##
1107+
# Utility Functions
1108+
##
1109+
10651110

10661111
def _extract_lit(e: BaseExpr) -> LitType:
10671112
"""

python/egglog/conversion.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from collections.abc import Callable
45
from contextlib import contextmanager
56
from contextvars import ContextVar
67
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, TypeVar, cast
8+
from typing import TYPE_CHECKING, Any, TypeVar, cast
89

910
from .declarations import *
1011
from .pretty import *
@@ -13,22 +14,22 @@
1314
from .type_constraint_solver import TypeConstraintError
1415

1516
if TYPE_CHECKING:
16-
from collections.abc import Callable, Generator
17+
from collections.abc import Generator
1718

1819
from .egraph import BaseExpr
1920
from .type_constraint_solver import TypeConstraintSolver
2021

21-
__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
22+
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
2223
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
23-
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
24+
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
2425
# Global declerations to store all convertable types so we can query if they have certain methods or not
2526
_CONVERSION_DECLS = Declarations.create()
2627
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
2728
# until we need them
2829
_TO_PROCESS_DECLS: list[DeclerationsLike] = []
2930

3031

31-
def _retrieve_conversion_decls() -> Declarations:
32+
def retrieve_conversion_decls() -> Declarations:
3233
_CONVERSION_DECLS.update(*_TO_PROCESS_DECLS)
3334
_TO_PROCESS_DECLS.clear()
3435
return _CONVERSION_DECLS
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
4950
to_type_name = process_tp(to_type)
5051
if not isinstance(to_type_name, JustTypeRef):
5152
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
52-
_register_converter(process_tp(from_type), to_type_name, fn, cost)
53+
_register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)
5354

5455

55-
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
56+
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
5657
"""
5758
Registers a converter from some type to an egglog type, if not already registered.
5859
@@ -97,15 +98,15 @@ class _ComposedConverter:
9798
We use the dataclass instead of the lambda to make it easier to debug.
9899
"""
99100

100-
a_b: Callable
101-
b_c: Callable
101+
a_b: Callable[[Any], RuntimeExpr]
102+
b_c: Callable[[Any], RuntimeExpr]
102103
b_args: tuple[JustTypeRef, ...]
103104

104-
def __call__(self, x: object) -> object:
105+
def __call__(self, x: Any) -> RuntimeExpr:
105106
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
106107
# when converting from A -> B
107108
if self.b_args:
108-
with with_type_args(self.b_args, _retrieve_conversion_decls):
109+
with with_type_args(self.b_args, retrieve_conversion_decls):
109110
first_res = self.a_b(x)
110111
else:
111112
first_res = self.a_b(x)
@@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
142143
return tp
143144

144145

145-
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
146-
"""
147-
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
148-
"""
149-
decls = _retrieve_conversion_decls()
150-
a_tp = _get_tp(a)
151-
b_tp = _get_tp(b)
152-
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153-
if not (
154-
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155-
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156-
):
157-
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
158-
a_converts_to = {
159-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
160-
}
161-
b_converts_to = {
162-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
163-
}
164-
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
165-
a_converts_to[a_tp] = 0
166-
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
167-
b_converts_to[b_tp] = 0
168-
common = set(a_converts_to) & set(b_converts_to)
169-
if not common:
170-
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
171-
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
146+
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
147+
# """
148+
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
149+
# """
150+
# decls = _retrieve_conversion_decls().copy()
151+
# if isinstance(a, RuntimeExpr):
152+
# decls |= a
153+
# if isinstance(b, RuntimeExpr):
154+
# decls |= b
155+
156+
# a_tp = _get_tp(a)
157+
# b_tp = _get_tp(b)
158+
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
159+
# if not (
160+
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
161+
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
162+
# ):
163+
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
164+
# a_converts_to = {
165+
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
166+
# }
167+
# b_converts_to = {
168+
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
169+
# }
170+
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
171+
# a_converts_to[a_tp] = 0
172+
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
173+
# b_converts_to[b_tp] = 0
174+
# common = set(a_converts_to) & set(b_converts_to)
175+
# if not common:
176+
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
177+
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
172178

173179

174180
def identity(x: object) -> object:
@@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
197203
def resolve_literal(
198204
tp: TypeOrVarRef,
199205
arg: object,
200-
decls: Callable[[], Declarations] = _retrieve_conversion_decls,
206+
decls: Callable[[], Declarations] = retrieve_conversion_decls,
201207
tcs: TypeConstraintSolver | None = None,
202208
cls_name: str | None = None,
203209
) -> RuntimeExpr:
@@ -208,12 +214,12 @@ def resolve_literal(
208214
209215
If it cannot be resolved, we assume that the value passed in will resolve it.
210216
"""
211-
arg_type = _get_tp(arg)
217+
arg_type = resolve_type(arg)
212218

213219
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
214220
try:
215221
tp_just = tp.to_just()
216-
except NotImplementedError:
222+
except TypeVarError:
217223
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
218224
# args first based on the existing type constraint solver
219225
if tcs:
@@ -258,7 +264,7 @@ def _debug_print_converers():
258264
source_to_targets[source].append(target)
259265

260266

261-
def _get_tp(x: object) -> JustTypeRef | type:
267+
def resolve_type(x: object) -> JustTypeRef | type:
262268
if isinstance(x, RuntimeExpr):
263269
return x.__egg_typed_expr__.tp
264270
tp = type(x)

0 commit comments

Comments
 (0)