diff --git a/docs/changelog.md b/docs/changelog.md index ddb82166..998d26ab 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -9,6 +9,8 @@ _This project uses semantic versioning_ Also changes the representation to be an index into a list instead of the ID, making egglog programs more deterministic. - Prefix constant declerations and unbound variables to not shadow let variables - BREAKING: Remove `simplify` since it was removed upstream. You can manually replace it with an insert, run, then extract. +- Change how anonymous functions are converted to remove metaprogramming and lift only the unbound variables as args +- Add support for getting the "value" of a function type with `.eval()`, i.e. `assert UnstableFn(f).eval() == f`. ## 10.0.2 (2025-06-22) diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index b2c5e438..d361d25b 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -8,16 +8,18 @@ from collections.abc import Callable from fractions import Fraction from functools import partial, reduce +from inspect import signature from types import FunctionType, MethodType from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload from typing_extensions import TypeVarTuple, Unpack -from .conversion import convert, converter, get_type_args +from egglog.declarations import TypedExprDecl + +from .conversion import convert, converter, get_type_args, resolve_literal from .declarations import * -from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method -from .functionalize import functionalize -from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction +from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method +from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate from .thunk import Thunk if TYPE_CHECKING: @@ -1022,6 +1024,44 @@ def __init__(self, f, *partial) -> None: ... @method(egg_fn="unstable-app") def __call__(self, *args: Unpack[TS]) -> T: ... + @method(preserve=True) + def eval(self) -> Callable[[Unpack[TS]], T]: + """ + If this is a constructor, returns either the callable directly or a `functools.partial` function if args are provided. + """ + assert isinstance(self, RuntimeExpr) + match self.__egg_typed_expr__.expr: + case PartialCallDecl(CallDecl() as call): + fn, args = _deconstruct_call_decl(self.__egg_decls_thunk__, call) + if not args: + return fn + return partial(fn, *args) + msg = "UnstableFn can only be evaluated if it is a function or a partial application of a function." + raise BuiltinEvalError(msg) + + +def _deconstruct_call_decl( + decls_thunk: Callable[[], Declarations], call: CallDecl +) -> tuple[Callable, tuple[object, ...]]: + """ + Deconstructs a CallDecl into a runtime callable and its arguments. + """ + args = call.args + arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args) + egg_bound = ( + JustTypeRef(call.callable.class_name, call.bound_tp_params or ()) + if isinstance(call.callable, (ClassMethodRef, InitRef, ClassVariableRef)) + else None + ) + if isinstance(call.callable, InitRef): + return RuntimeClass( + decls_thunk, + TypeRefWithVars( + call.callable.class_name, + ), + ), arg_exprs + return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs + # Method Type is for builtins like __getitem__ converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__)) @@ -1029,39 +1069,44 @@ def __call__(self, *args: Unpack[TS]) -> T: ... converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args)) -def _convert_function(a: FunctionType) -> UnstableFn: +def _convert_function(fn: FunctionType) -> UnstableFn: """ - Converts a function type to an unstable function + Converts a function type to an unstable function. This function will be an anon function in egglog. - Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals - which are runtime expressions with `var`s in them and add them as args to the function + Would just be UnstableFn(function(a)) but we have to account for unbound vars within the body. + + This means that we have to turn all of those unbound vars into args to the function, and then + partially apply them, alongside creating a default rewrite for the function. """ - # Update annotations of a to be the type we are trying to convert to - return_tp, *arg_tps = get_type_args() - a.__annotations__ = { - "return": return_tp, - # The first varnames should always be the arg names - **dict(zip(a.__code__.co_varnames, arg_tps, strict=False)), - } - # Modify name to make it unique - # a.__name__ = f"{a.__name__} {hash(a.__code__)}" - transformed_fn = functionalize(a, value_to_annotation) - assert isinstance(transformed_fn, partial) - return UnstableFn( - function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func), - *transformed_fn.args, + decls = Declarations() + return_type, *arg_types = [resolve_type_annotation_mutate(decls, tp) for tp in get_type_args()] + arg_names = [p.name for p in signature(fn).parameters.values()] + arg_decls = [ + TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True) + ] + res = resolve_literal( + return_type, fn(*(RuntimeExpr.__from_values__(decls, a) for a in arg_decls)), Thunk.value(decls) ) + res_expr = res.__egg_typed_expr__ + decls |= res + # these are all the args that appear in the body that are not bound by the args of the function + unbound_vars = list(collect_unbound_vars(res_expr) - set(arg_decls)) + # prefix the args with them + fn_ref = UnnamedFunctionRef(tuple(unbound_vars + arg_decls), res_expr) + rewrite_decl = DefaultRewriteDecl(fn_ref, res_expr.expr, subsume=True) + ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, get_current_ruleset()) + ruleset_decls |= res - -def value_to_annotation(a: object) -> type | None: - # only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function - if not isinstance(a, RuntimeExpr): - return None - return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var())) + fn = RuntimeFunction(Thunk.value(decls), Thunk.value(fn_ref)) + return UnstableFn(fn, *(RuntimeExpr.__from_values__(decls, v) for v in unbound_vars)) converter(FunctionType, UnstableFn, _convert_function) +## +# Utility Functions +## + def _extract_lit(e: BaseExpr) -> LitType: """ diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index c7a0bd65..6725ccd2 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -77,6 +77,7 @@ "UnboundVarDecl", "UnionDecl", "UnnamedFunctionRef", + "collect_unbound_vars", "replace_typed_expr", "upcast_declerations", ] @@ -683,6 +684,28 @@ def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl: return _inner(typed_expr) +def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]: + """ + Returns the set of all unbound vars + """ + seen = set[TypedExprDecl]() + unbound_vars = set[TypedExprDecl]() + + def visit(typed_expr: TypedExprDecl) -> None: + if typed_expr in seen: + return + seen.add(typed_expr) + match typed_expr.expr: + case CallDecl(_, args) | PartialCallDecl(CallDecl(_, args)): + for arg in args: + visit(arg) + case UnboundVarDecl(_): + unbound_vars.add(typed_expr) + + visit(typed_expr) + return unbound_vars + + ## # Schedules ## diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 520baeaa..90df1078 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -29,6 +29,7 @@ from . import bindings from .conversion import * +from .conversion import convert_to_same_type, resolve_literal from .declarations import * from .egraph_state import * from .ipython_magic import IN_IPYTHON @@ -281,7 +282,6 @@ def function( mutates_first_arg: bool = ..., unextractable: bool = ..., ruleset: Ruleset | None = ..., - use_body_as_name: bool = ..., subsume: bool = ..., ) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ... @@ -467,7 +467,7 @@ def _generate_class_decls( # noqa: C901,PLR0912 decls.set_function_decl(ref, decl) continue try: - _, add_rewrite = _fn_decl( + add_rewrite = _fn_decl( decls, egg_fn, ref, @@ -505,7 +505,6 @@ class _FunctionConstructor: merge: Callable[[object, object], object] | None = None unextractable: bool = False ruleset: Ruleset | None = None - use_body_as_name: bool = False subsume: bool = False def __call__(self, fn: Callable) -> RuntimeFunction: @@ -513,11 +512,10 @@ def __call__(self, fn: Callable) -> RuntimeFunction: def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]: decls = Declarations() - ref = None if self.use_body_as_name else FunctionRef(fn.__name__) - ref, add_rewrite = _fn_decl( + add_rewrite = _fn_decl( decls, self.egg_fn, - ref, + ref := FunctionRef(fn.__name__), fn, self.hint_locals, self.cost, @@ -535,8 +533,7 @@ def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]: def _fn_decl( decls: Declarations, egg_name: str | None, - # If ref is Callable, then generate the ref from the function name - ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None, + ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef, fn: object, # Pass in the locals, retrieved from the frame when wrapping, # so that we support classes and function defined inside of other functions (which won't show up in the globals) @@ -549,7 +546,7 @@ def _fn_decl( ruleset: Ruleset | None = None, unextractable: bool = False, reverse_args: bool = False, -) -> tuple[CallableRef, Callable[[], None]]: +) -> Callable[[], None]: """ Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable. """ @@ -619,50 +616,39 @@ def _fn_decl( # defer this in generator so it doesn't resolve for builtins eagerly args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True)) - res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef - res_thunk: Callable[[], object] - # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it - if not ref: - tuple_args = tuple(args) - res = _create_default_value(decls, ref, fn, tuple_args, ruleset) - assert isinstance(res, RuntimeExpr) - res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__) - decls._unnamed_functions.add(res_ref) - res_thunk = Thunk.value(res) + return_type_is_eqsort = ( + not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False + ) + is_constructor = not is_builtin and return_type_is_eqsort and merged is None + signature_ = FunctionSignature( + return_type=None if mutates_first_arg else return_type, + var_arg_type=var_arg_type, + arg_types=arg_types, + arg_names=arg_names, + arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults), + reverse_args=reverse_args, + ) + decl: ConstructorDecl | FunctionDecl + if is_constructor: + decl = ConstructorDecl(signature_, egg_name, cost, unextractable) else: - return_type_is_eqsort = ( - not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False - ) - is_constructor = not is_builtin and return_type_is_eqsort and merged is None - signature_ = FunctionSignature( - return_type=None if mutates_first_arg else return_type, - var_arg_type=var_arg_type, - arg_types=arg_types, - arg_names=arg_names, - arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults), - reverse_args=reverse_args, + if cost is not None: + msg = "Cost can only be set for constructors" + raise ValueError(msg) + if unextractable: + msg = "Unextractable can only be set for constructors" + raise ValueError(msg) + decl = FunctionDecl( + signature=signature_, + egg_name=egg_name, + merge=merged.__egg_typed_expr__.expr if merged is not None else None, + builtin=is_builtin, ) - decl: ConstructorDecl | FunctionDecl - if is_constructor: - decl = ConstructorDecl(signature_, egg_name, cost, unextractable) - else: - if cost is not None: - msg = "Cost can only be set for constructors" - raise ValueError(msg) - if unextractable: - msg = "Unextractable can only be set for constructors" - raise ValueError(msg) - decl = FunctionDecl( - signature=signature_, - egg_name=egg_name, - merge=merged.__egg_typed_expr__.expr if merged is not None else None, - builtin=is_builtin, - ) - res_ref = ref - decls.set_function_decl(ref, decl) - res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}") - return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume) + decls.set_function_decl(ref, decl) + return Thunk.fn( + _add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}" + ) # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value @@ -736,13 +722,15 @@ def _constant_thunk( return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref)) -def _create_default_value( +def _add_default_rewrite_function( decls: Declarations, - ref: CallableRef | None, + ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef, fn: Callable, args: Iterable[TypedExprDecl], ruleset: Ruleset | None, -) -> object: + subsume: bool, + res_type: TypeOrVarRef, +) -> None: args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args] # If this is a classmethod, add the class as the first arg @@ -750,21 +738,8 @@ def _create_default_value( tp = decls.get_paramaterized_class(ref.class_name) args.insert(0, RuntimeClass(Thunk.value(decls), tp)) with set_current_ruleset(ruleset): - return fn(*args) - - -def _add_default_rewrite_function( - decls: Declarations, - ref: CallableRef, - res_type: TypeOrVarRef, - ruleset: Ruleset | None, - value_thunk: Callable[[], object], - subsume: bool, -) -> None: - """ - Helper functions that resolves a value thunk to create the default value. - """ - _add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume) + res = fn(*args) + _add_default_rewrite(decls, ref, res_type, res, ruleset, subsume) def _add_default_rewrite( @@ -784,6 +759,13 @@ def _add_default_rewrite( return resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls)) rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume) + ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset) + ruleset_decls |= resolved_value + + +def _add_default_rewrite_inner( + decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None +) -> Declarations: if ruleset: ruleset_decls = ruleset._current_egg_decls ruleset_decl = ruleset.__egg_ruleset__ @@ -791,7 +773,7 @@ def _add_default_rewrite( ruleset_decls = decls ruleset_decl = decls.default_ruleset ruleset_decl.rules.append(rewrite_decl) - ruleset_decls |= resolved_value + return ruleset_decls def _last_param_variable(params: list[Parameter]) -> bool: diff --git a/python/egglog/functionalize.py b/python/egglog/functionalize.py deleted file mode 100644 index 133048f6..00000000 --- a/python/egglog/functionalize.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from functools import partial -from inspect import Parameter, signature -from typing import Any, TypeVar, cast - -__all__ = ["functionalize"] - - -T = TypeVar("T", bound=Callable) - - -# TODO: Add `to_lift` param so that we only transform those with vars in them to args - - -def functionalize(f: T, get_annotation: Callable[[object], type | None]) -> T: - """ - Takes a function and returns a new function with all names (co_names) and free variables (co_freevars) added as arguments - and then partially applied with their values. The second arg, get_annotation, will be applied to all values - to get their type annotation. If it is None, that arg will not be added as a parameter. - - For example if you have: - - def get_annotation(x): return int if x <= 10 else None - - g = 10 - def f(a, a2): - def h(b: Z): - return a + a2 + b + g - - return functionalize(h, get_annotation) - res = f(9, 11) - - It should be equivalent to (according to body, signature, and annotations) (Note that the new arguments will be positional only): - - def h(a: get_annotation(a), g: get_annotation(g), b: Z): - return a + b + g - res = partial(h, a, g) - """ - code = f.__code__ - names = tuple(n for n in code.co_names if n in f.__globals__) - free_vars = code.co_freevars - - global_values: list[Any] = [f.__globals__[name] for name in names] - free_var_values = [cell.cell_contents for cell in f.__closure__] if f.__closure__ else [] - assert len(free_var_values) == len(free_vars), "Free vars and their values do not match" - global_values_filtered = [ - (i, name, value, annotation) - for i, (name, value) in enumerate(zip(names, global_values, strict=True)) - if (annotation := get_annotation(value)) is not None - ] - free_var_values_filtered = [ - (i, name, value, annotation) - for i, (name, value) in enumerate(zip(free_vars, free_var_values, strict=True)) - if (annotation := get_annotation(value)) is not None - ] - additional_arg_filtered = global_values_filtered + free_var_values_filtered - - # Create a wrapper function - def wrapper(*args): - # Split args into names, free vars and other args - name_args, free_var_args, rest_args = ( - args[: (n_names := len(global_values_filtered))], - args[n_names : (n_args := len(additional_arg_filtered))], - args[n_args:], - ) - # Update globals with names - f.__globals__.update({ - name: arg for (_, name, _, _), arg in zip(global_values_filtered, name_args, strict=False) - }) - # update function free vars with free var args - for (i, _, _, _), value in zip(free_var_values_filtered, free_var_args, strict=True): - assert f.__closure__, "Function does not have closure" - f.__closure__[i].cell_contents = value - return f(*rest_args) - - # Set the signature of the new function to a signature with the free vars and names added as arguments - orig_signature = signature(f) - wrapper.__signature__ = orig_signature.replace( # type: ignore[attr-defined] - parameters=[ - *[Parameter(n, Parameter.POSITIONAL_OR_KEYWORD) for _, n, _, _ in additional_arg_filtered], - *orig_signature.parameters.values(), - ] - ) - # Set the annotations of the new function to the annotations of the original function + annotations of passed in values - wrapper.__annotations__ = f.__annotations__ | {n: a for _, n, _, a in additional_arg_filtered} - wrapper.__name__ = f.__name__ - - # Partially apply the wrapper function with the current values of the free vars - return cast("T", partial(wrapper, *(v for _, _, v, _ in additional_arg_filtered))) diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 1963def9..487c3863 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -306,6 +306,17 @@ class RuntimeFunction(DelayedDeclerations): # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self __egg_bound__: JustTypeRef | RuntimeExpr | None = None + def __eq__(self, other: object) -> bool: + """ + Support equality for runtime comparison of egglog functions. + """ + if not isinstance(other, RuntimeFunction): + return NotImplemented + return self.__egg_ref__ == other.__egg_ref__ and bool(self.__egg_bound__ == other.__egg_bound__) + + def __hash__(self) -> int: + return hash((self.__egg_ref__, self.__egg_bound__)) + @property def __egg_ref__(self) -> CallableRef: return self.__egg_ref_thunk__() diff --git a/python/tests/test_functionalize.py b/python/tests/test_functionalize.py deleted file mode 100644 index c9b04a6d..00000000 --- a/python/tests/test_functionalize.py +++ /dev/null @@ -1,61 +0,0 @@ -from collections.abc import Callable -from functools import partial -from inspect import signature -from typing import get_type_hints - -from egglog.functionalize import functionalize - - -def get_annotation(x: object) -> type | None: - return type(x) if x != 1 else None - - -x = "x" - - -def outer(y1: str, y2: int) -> Callable[[str, int], tuple[str, str, int, str, int]]: - def inner(z1: str, z2: int) -> tuple[str, str, int, str, int]: - return (x, y1, y2, z1, z2) - - return functionalize(inner, get_annotation) - - -res = outer("y1", 1) - - -def test_partial(): - assert isinstance(res, partial) - assert res.args == ("x", "y1") - - -def test_signature(): - assert isinstance(res, partial) - - sig = signature(res.func) - assert list(sig.parameters) == ["x", "y1", "z1", "z2"] - - -def test_annotations(): - assert isinstance(res, partial) - - annotations = get_type_hints(res.func) - assert annotations == { - "x": str, - "y1": str, - "z1": str, - "z2": int, - "return": tuple[str, str, int, str, int], - } - - -def test_call(): - assert res("z1", 2) == ("x", "y1", 1, "z1", 2) - - -def test_call_again(): - assert res("z1_", 22) == ("x", "y1", 1, "z1_", 22) - - -def test_name(): - assert isinstance(res, partial) - assert res.func.__name__ == "inner" diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 974e8600..e8aa1c34 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -5,6 +5,7 @@ import pathlib from copy import copy from fractions import Fraction +from functools import partial from typing import ClassVar, TypeAlias, TypeVar import pytest @@ -526,6 +527,21 @@ def test_big_rat(self): def test_multiset(self): assert list(MultiSet(i64(1), i64(1))) == [i64(1), i64(1)] + def test_unstable_fn(self): + class Math(Expr): + def __init__(self) -> None: ... + + @function + def f(x: Math) -> Math: ... + + u_f = UnstableFn(f) + assert u_f.eval() == f + p_u_f = UnstableFn(f, Math()) + value = p_u_f.eval() + assert isinstance(value, partial) + assert value.func == f + assert value.args == (Math(),) + # def test_egglog_string(): # egraph = EGraph(save_egglog_string=True)