Skip to content

Commit 00d1b65

Browse files
Simplify how anonymous functions are parsed
Updates the way that we handle passing higher order functions into egglog to remove some previous metaprogramming magic and replace it with a simpler algorithm that just removes any unbound variables left in a function and hoists them as args to the function.
1 parent 4a43cd8 commit 00d1b65

File tree

8 files changed

+176
-249
lines changed

8 files changed

+176
-249
lines changed

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ _This project uses semantic versioning_
99
Also changes the representation to be an index into a list instead of the ID, making egglog programs more deterministic.
1010
- Prefix constant declarations and unbound variables to not shadow let variables
1111
- BREAKING: Remove `simplify` since it was removed upstream. You can manually replace it with an insert, run, then extract.
12+
- Change how anonymous functions are converted to remove metaprogramming and lift only the unbound variables as args
13+
- Add support for getting the "value" of a function type with `.eval()`, i.e. `assert UnstableFn(f).eval() == f`.
1214

1315
## 10.0.2 (2025-06-22)
1416

python/egglog/builtins.py

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
from collections.abc import Callable
99
from fractions import Fraction
1010
from functools import partial, reduce
11+
from inspect import signature
1112
from types import FunctionType, MethodType
1213
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload
1314

1415
from typing_extensions import TypeVarTuple, Unpack
1516

16-
from .conversion import convert, converter, get_type_args
17+
from egglog.declarations import TypedExprDecl
18+
19+
from .conversion import convert, converter, get_type_args, resolve_literal
1720
from .declarations import *
18-
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
19-
from .functionalize import functionalize
20-
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
21+
from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method
22+
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate
2123
from .thunk import Thunk
2224

2325
if TYPE_CHECKING:
@@ -1022,46 +1024,89 @@ def __init__(self, f, *partial) -> None: ...
10221024
@method(egg_fn="unstable-app")
10231025
def __call__(self, *args: Unpack[TS]) -> T: ...
10241026

1027+
@method(preserve=True)
1028+
def eval(self) -> Callable[[Unpack[TS]], T]:
1029+
"""
1030+
If this is a constructor, returns either the callable directly or a `functools.partial` function if args are provided.
1031+
"""
1032+
assert isinstance(self, RuntimeExpr)
1033+
match self.__egg_typed_expr__.expr:
1034+
case PartialCallDecl(CallDecl() as call):
1035+
fn, args = _deconstruct_call_decl(self.__egg_decls_thunk__, call)
1036+
if not args:
1037+
return fn
1038+
return partial(fn, *args)
1039+
msg = "UnstableFn can only be evaluated if it is a function or a partial application of a function."
1040+
raise BuiltinEvalError(msg)
1041+
1042+
1043+
def _deconstruct_call_decl(
    decls_thunk: Callable[[], Declarations], call: CallDecl
) -> tuple[Callable, tuple[object, ...]]:
    """
    Deconstructs a CallDecl into a runtime callable and its arguments.

    Returns a pair of the callable and a tuple with one `RuntimeExpr` per argument of the call.
    Init calls give back the `RuntimeClass` itself, every other call a `RuntimeFunction`.
    Type parameters bound on the call are kept on the returned callable.
    """
    callable_ref = call.callable
    arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in call.args)
    bound_params = call.bound_tp_params or ()
    if isinstance(callable_ref, InitRef):
        # Keep the bound type parameters on the class, so that a parameterized class is not
        # handed back unparameterized.
        class_tp = TypeRefWithVars(callable_ref.class_name, tuple(tp.to_var() for tp in bound_params))
        return RuntimeClass(decls_thunk, class_tp), arg_exprs
    # Only refs that live on a class carry a bound class type; plain functions and methods get None.
    egg_bound = (
        JustTypeRef(callable_ref.class_name, bound_params)
        if isinstance(callable_ref, (ClassMethodRef, ClassVariableRef))
        else None
    )
    return RuntimeFunction(decls_thunk, Thunk.value(callable_ref), egg_bound), arg_exprs
1064+
10251065

10261066
# Method Type is for builtins like __getitem__
10271067
converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__))
10281068
converter(RuntimeFunction, UnstableFn, UnstableFn)
10291069
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
10301070

10311071

1032-
def _convert_function(a: FunctionType) -> UnstableFn:
1072+
def _convert_function(fn: FunctionType) -> UnstableFn:
10331073
"""
1034-
Converts a function type to an unstable function
1074+
Converts a function type to an unstable function. This function will be an anon function in egglog.
10351075
1036-
Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
1037-
which are runtime expressions with `var`s in them and add them as args to the function
1076+
Would just be UnstableFn(function(fn)) but we have to account for unbound vars within the body.
1077+
1078+
This means that we have to turn all of those unbound vars into args to the function, and then
1079+
partially apply them, alongside creating a default rewrite for the function.
10381080
"""
1039-
# Update annotations of a to be the type we are trying to convert to
1040-
return_tp, *arg_tps = get_type_args()
1041-
a.__annotations__ = {
1042-
"return": return_tp,
1043-
# The first varnames should always be the arg names
1044-
**dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
1045-
}
1046-
# Modify name to make it unique
1047-
# a.__name__ = f"{a.__name__} {hash(a.__code__)}"
1048-
transformed_fn = functionalize(a, value_to_annotation)
1049-
assert isinstance(transformed_fn, partial)
1050-
return UnstableFn(
1051-
function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
1052-
*transformed_fn.args,
1081+
decls = Declarations()
1082+
return_type, *arg_types = [resolve_type_annotation_mutate(decls, tp) for tp in get_type_args()]
1083+
arg_names = [p.name for p in signature(fn).parameters.values()]
1084+
arg_decls = [
1085+
TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True)
1086+
]
1087+
res = resolve_literal(
1088+
return_type, fn(*(RuntimeExpr.__from_values__(decls, a) for a in arg_decls)), Thunk.value(decls)
10531089
)
1090+
res_expr = res.__egg_typed_expr__
1091+
decls |= res
1092+
# these are all the args that appear in the body that are not bound by the args of the function
1093+
unbound_vars = list(collect_unbound_vars(res_expr) - set(arg_decls))
1094+
# prefix the args with them
1095+
fn_ref = UnnamedFunctionRef(tuple(unbound_vars + arg_decls), res_expr)
1096+
rewrite_decl = DefaultRewriteDecl(fn_ref, res_expr.expr, subsume=True)
1097+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, get_current_ruleset())
1098+
ruleset_decls |= res
10541099

1055-
1056-
def value_to_annotation(a: object) -> type | None:
1057-
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
1058-
if not isinstance(a, RuntimeExpr):
1059-
return None
1060-
return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
1100+
fn = RuntimeFunction(Thunk.value(decls), Thunk.value(fn_ref))
1101+
return UnstableFn(fn, *(RuntimeExpr.__from_values__(decls, v) for v in unbound_vars))
10611102

10621103

10631104
converter(FunctionType, UnstableFn, _convert_function)
10641105

1106+
##
1107+
# Utility Functions
1108+
##
1109+
10651110

10661111
def _extract_lit(e: BaseExpr) -> LitType:
10671112
"""

python/egglog/declarations.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
"UnboundVarDecl",
7878
"UnionDecl",
7979
"UnnamedFunctionRef",
80+
"collect_unbound_vars",
8081
"replace_typed_expr",
8182
"upcast_declerations",
8283
]
@@ -683,6 +684,28 @@ def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
683684
return _inner(typed_expr)
684685

685686

687+
def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
688+
"""
689+
Returns the set of all unbound vars
690+
"""
691+
seen = set[TypedExprDecl]()
692+
unbound_vars = set[TypedExprDecl]()
693+
694+
def visit(typed_expr: TypedExprDecl) -> None:
695+
if typed_expr in seen:
696+
return
697+
seen.add(typed_expr)
698+
match typed_expr.expr:
699+
case CallDecl(_, args) | PartialCallDecl(CallDecl(_, args)):
700+
for arg in args:
701+
visit(arg)
702+
case UnboundVarDecl(_):
703+
unbound_vars.add(typed_expr)
704+
705+
visit(typed_expr)
706+
return unbound_vars
707+
708+
686709
##
687710
# Schedules
688711
##

python/egglog/egraph.py

Lines changed: 51 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from . import bindings
3131
from .conversion import *
32+
from .conversion import convert_to_same_type, resolve_literal
3233
from .declarations import *
3334
from .egraph_state import *
3435
from .ipython_magic import IN_IPYTHON
@@ -281,7 +282,6 @@ def function(
281282
mutates_first_arg: bool = ...,
282283
unextractable: bool = ...,
283284
ruleset: Ruleset | None = ...,
284-
use_body_as_name: bool = ...,
285285
subsume: bool = ...,
286286
) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...
287287

@@ -467,7 +467,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
467467
decls.set_function_decl(ref, decl)
468468
continue
469469
try:
470-
_, add_rewrite = _fn_decl(
470+
add_rewrite = _fn_decl(
471471
decls,
472472
egg_fn,
473473
ref,
@@ -505,19 +505,17 @@ class _FunctionConstructor:
505505
merge: Callable[[object, object], object] | None = None
506506
unextractable: bool = False
507507
ruleset: Ruleset | None = None
508-
use_body_as_name: bool = False
509508
subsume: bool = False
510509

511510
def __call__(self, fn: Callable) -> RuntimeFunction:
512511
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
513512

514513
def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
515514
decls = Declarations()
516-
ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
517-
ref, add_rewrite = _fn_decl(
515+
add_rewrite = _fn_decl(
518516
decls,
519517
self.egg_fn,
520-
ref,
518+
ref := FunctionRef(fn.__name__),
521519
fn,
522520
self.hint_locals,
523521
self.cost,
@@ -535,8 +533,7 @@ def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
535533
def _fn_decl(
536534
decls: Declarations,
537535
egg_name: str | None,
538-
# If ref is Callable, then generate the ref from the function name
539-
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
536+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
540537
fn: object,
541538
# Pass in the locals, retrieved from the frame when wrapping,
542539
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -549,7 +546,7 @@ def _fn_decl(
549546
ruleset: Ruleset | None = None,
550547
unextractable: bool = False,
551548
reverse_args: bool = False,
552-
) -> tuple[CallableRef, Callable[[], None]]:
549+
) -> Callable[[], None]:
553550
"""
554551
Sets the function decl for the function object and returns a thunk that adds the default rewrite for it.
555552
"""
@@ -619,50 +616,39 @@ def _fn_decl(
619616

620617
# defer this in generator so it doesn't resolve for builtins eagerly
621618
args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
622-
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
623-
res_thunk: Callable[[], object]
624-
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
625-
if not ref:
626-
tuple_args = tuple(args)
627-
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
628-
assert isinstance(res, RuntimeExpr)
629-
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
630-
decls._unnamed_functions.add(res_ref)
631-
res_thunk = Thunk.value(res)
632619

620+
return_type_is_eqsort = (
621+
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
622+
)
623+
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
624+
signature_ = FunctionSignature(
625+
return_type=None if mutates_first_arg else return_type,
626+
var_arg_type=var_arg_type,
627+
arg_types=arg_types,
628+
arg_names=arg_names,
629+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
630+
reverse_args=reverse_args,
631+
)
632+
decl: ConstructorDecl | FunctionDecl
633+
if is_constructor:
634+
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
633635
else:
634-
return_type_is_eqsort = (
635-
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
636-
)
637-
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
638-
signature_ = FunctionSignature(
639-
return_type=None if mutates_first_arg else return_type,
640-
var_arg_type=var_arg_type,
641-
arg_types=arg_types,
642-
arg_names=arg_names,
643-
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
644-
reverse_args=reverse_args,
636+
if cost is not None:
637+
msg = "Cost can only be set for constructors"
638+
raise ValueError(msg)
639+
if unextractable:
640+
msg = "Unextractable can only be set for constructors"
641+
raise ValueError(msg)
642+
decl = FunctionDecl(
643+
signature=signature_,
644+
egg_name=egg_name,
645+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
646+
builtin=is_builtin,
645647
)
646-
decl: ConstructorDecl | FunctionDecl
647-
if is_constructor:
648-
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
649-
else:
650-
if cost is not None:
651-
msg = "Cost can only be set for constructors"
652-
raise ValueError(msg)
653-
if unextractable:
654-
msg = "Unextractable can only be set for constructors"
655-
raise ValueError(msg)
656-
decl = FunctionDecl(
657-
signature=signature_,
658-
egg_name=egg_name,
659-
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
660-
builtin=is_builtin,
661-
)
662-
res_ref = ref
663-
decls.set_function_decl(ref, decl)
664-
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
665-
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
648+
decls.set_function_decl(ref, decl)
649+
return Thunk.fn(
650+
_add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
651+
)
666652

667653

668654
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -736,35 +722,24 @@ def _constant_thunk(
736722
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
737723

738724

739-
def _create_default_value(
725+
def _add_default_rewrite_function(
740726
decls: Declarations,
741-
ref: CallableRef | None,
727+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
742728
fn: Callable,
743729
args: Iterable[TypedExprDecl],
744730
ruleset: Ruleset | None,
745-
) -> object:
731+
subsume: bool,
732+
res_type: TypeOrVarRef,
733+
) -> None:
746734
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
747735

748736
# If this is a classmethod, add the class as the first arg
749737
if isinstance(ref, ClassMethodRef):
750738
tp = decls.get_paramaterized_class(ref.class_name)
751739
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
752740
with set_current_ruleset(ruleset):
753-
return fn(*args)
754-
755-
756-
def _add_default_rewrite_function(
757-
decls: Declarations,
758-
ref: CallableRef,
759-
res_type: TypeOrVarRef,
760-
ruleset: Ruleset | None,
761-
value_thunk: Callable[[], object],
762-
subsume: bool,
763-
) -> None:
764-
"""
765-
Helper functions that resolves a value thunk to create the default value.
766-
"""
767-
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
741+
res = fn(*args)
742+
_add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)
768743

769744

770745
def _add_default_rewrite(
@@ -784,14 +759,21 @@ def _add_default_rewrite(
784759
return
785760
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
786761
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
762+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
763+
ruleset_decls |= resolved_value
764+
765+
766+
def _add_default_rewrite_inner(
767+
decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
768+
) -> Declarations:
787769
if ruleset:
788770
ruleset_decls = ruleset._current_egg_decls
789771
ruleset_decl = ruleset.__egg_ruleset__
790772
else:
791773
ruleset_decls = decls
792774
ruleset_decl = decls.default_ruleset
793775
ruleset_decl.rules.append(rewrite_decl)
794-
ruleset_decls |= resolved_value
776+
return ruleset_decls
795777

796778

797779
def _last_param_variable(params: list[Parameter]) -> bool:

0 commit comments

Comments
 (0)