Skip to content

Commit 5aa044f

Browse files
Merge pull request #314 from egraphs-good/reimplemention-tmp-fn
Simplify how anonymous functions are parsed
2 parents c1d8d81 + 00d1b65 commit 5aa044f

File tree

8 files changed

+176
-249
lines changed

8 files changed

+176
-249
lines changed

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ _This project uses semantic versioning_
99
Also changes the representation to be an index into a list instead of the ID, making egglog programs more deterministic.
1010
- Prefix constant declarations and unbound variables to not shadow let variables
1111
- BREAKING: Remove `simplify` since it was removed upstream. You can manually replace it with an insert, run, then extract.
12+
- Change how anonymous functions are converted to remove metaprogramming and lift only the unbound variables as args
13+
- Add support for getting the "value" of a function type with `.eval()`, i.e. `assert UnstableFn(f).eval() == f`.
1214

1315
## 10.0.2 (2025-06-22)
1416

python/egglog/builtins.py

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
from collections.abc import Callable
99
from fractions import Fraction
1010
from functools import partial, reduce
11+
from inspect import signature
1112
from types import FunctionType, MethodType
1213
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload
1314

1415
from typing_extensions import TypeVarTuple, Unpack
1516

16-
from .conversion import convert, converter, get_type_args
17+
from egglog.declarations import TypedExprDecl
18+
19+
from .conversion import convert, converter, get_type_args, resolve_literal
1720
from .declarations import *
18-
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
19-
from .functionalize import functionalize
20-
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
21+
from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method
22+
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate
2123
from .thunk import Thunk
2224

2325
if TYPE_CHECKING:
@@ -1022,46 +1024,89 @@ def __init__(self, f, *partial) -> None: ...
10221024
@method(egg_fn="unstable-app")
10231025
def __call__(self, *args: Unpack[TS]) -> T: ...
10241026

1027+
@method(preserve=True)
1028+
def eval(self) -> Callable[[Unpack[TS]], T]:
1029+
"""
1030+
If this wraps a function or a partial application of one, returns the callable directly when no args are bound, or a `functools.partial` of it with the bound args otherwise.
1031+
"""
1032+
assert isinstance(self, RuntimeExpr)
1033+
match self.__egg_typed_expr__.expr:
1034+
case PartialCallDecl(CallDecl() as call):
1035+
fn, args = _deconstruct_call_decl(self.__egg_decls_thunk__, call)
1036+
if not args:
1037+
return fn
1038+
return partial(fn, *args)
1039+
msg = "UnstableFn can only be evaluated if it is a function or a partial application of a function."
1040+
raise BuiltinEvalError(msg)
1041+
1042+
1043+
def _deconstruct_call_decl(
1044+
decls_thunk: Callable[[], Declarations], call: CallDecl
1045+
) -> tuple[Callable, tuple[object, ...]]:
1046+
"""
1047+
Deconstructs a CallDecl into a runtime callable and its arguments.
1048+
"""
1049+
args = call.args
1050+
arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
1051+
egg_bound = (
1052+
JustTypeRef(call.callable.class_name, call.bound_tp_params or ())
1053+
if isinstance(call.callable, (ClassMethodRef, InitRef, ClassVariableRef))
1054+
else None
1055+
)
1056+
if isinstance(call.callable, InitRef):
1057+
return RuntimeClass(
1058+
decls_thunk,
1059+
TypeRefWithVars(
1060+
call.callable.class_name,
1061+
),
1062+
), arg_exprs
1063+
return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs
1064+
10251065

10261066
# Method Type is for builtins like __getitem__
10271067
converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__))
10281068
converter(RuntimeFunction, UnstableFn, UnstableFn)
10291069
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
10301070

10311071

1032-
def _convert_function(a: FunctionType) -> UnstableFn:
1072+
def _convert_function(fn: FunctionType) -> UnstableFn:
10331073
"""
1034-
Converts a function type to an unstable function
1074+
Converts a function type to an unstable function. This function will be an anon function in egglog.
10351075
1036-
Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
1037-
which are runtime expressions with `var`s in them and add them as args to the function
1076+
Would just be UnstableFn(function(fn)) but we have to account for unbound vars within the body.
1077+
1078+
This means that we have to turn all of those unbound vars into args to the function, and then
1079+
partially apply them, alongside creating a default rewrite for the function.
10381080
"""
1039-
# Update annotations of a to be the type we are trying to convert to
1040-
return_tp, *arg_tps = get_type_args()
1041-
a.__annotations__ = {
1042-
"return": return_tp,
1043-
# The first varnames should always be the arg names
1044-
**dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
1045-
}
1046-
# Modify name to make it unique
1047-
# a.__name__ = f"{a.__name__} {hash(a.__code__)}"
1048-
transformed_fn = functionalize(a, value_to_annotation)
1049-
assert isinstance(transformed_fn, partial)
1050-
return UnstableFn(
1051-
function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
1052-
*transformed_fn.args,
1081+
decls = Declarations()
1082+
return_type, *arg_types = [resolve_type_annotation_mutate(decls, tp) for tp in get_type_args()]
1083+
arg_names = [p.name for p in signature(fn).parameters.values()]
1084+
arg_decls = [
1085+
TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True)
1086+
]
1087+
res = resolve_literal(
1088+
return_type, fn(*(RuntimeExpr.__from_values__(decls, a) for a in arg_decls)), Thunk.value(decls)
10531089
)
1090+
res_expr = res.__egg_typed_expr__
1091+
decls |= res
1092+
# these are all the args that appear in the body that are not bound by the args of the function
1093+
unbound_vars = list(collect_unbound_vars(res_expr) - set(arg_decls))
1094+
# prefix the args with them
1095+
fn_ref = UnnamedFunctionRef(tuple(unbound_vars + arg_decls), res_expr)
1096+
rewrite_decl = DefaultRewriteDecl(fn_ref, res_expr.expr, subsume=True)
1097+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, get_current_ruleset())
1098+
ruleset_decls |= res
10541099

1055-
1056-
def value_to_annotation(a: object) -> type | None:
1057-
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
1058-
if not isinstance(a, RuntimeExpr):
1059-
return None
1060-
return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
1100+
fn = RuntimeFunction(Thunk.value(decls), Thunk.value(fn_ref))
1101+
return UnstableFn(fn, *(RuntimeExpr.__from_values__(decls, v) for v in unbound_vars))
10611102

10621103

10631104
converter(FunctionType, UnstableFn, _convert_function)
10641105

1106+
##
1107+
# Utility Functions
1108+
##
1109+
10651110

10661111
def _extract_lit(e: BaseExpr) -> LitType:
10671112
"""

python/egglog/declarations.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
"UnboundVarDecl",
7878
"UnionDecl",
7979
"UnnamedFunctionRef",
80+
"collect_unbound_vars",
8081
"replace_typed_expr",
8182
"upcast_declerations",
8283
]
@@ -683,6 +684,28 @@ def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
683684
return _inner(typed_expr)
684685

685686

687+
def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
688+
"""
689+
Returns the set of all unbound vars reachable from `typed_expr` through the args of calls and partial calls.
690+
"""
691+
seen = set[TypedExprDecl]()
692+
unbound_vars = set[TypedExprDecl]()
693+
694+
def visit(typed_expr: TypedExprDecl) -> None:
695+
if typed_expr in seen:
696+
return
697+
seen.add(typed_expr)
698+
match typed_expr.expr:
699+
case CallDecl(_, args) | PartialCallDecl(CallDecl(_, args)):
700+
for arg in args:
701+
visit(arg)
702+
case UnboundVarDecl(_):
703+
unbound_vars.add(typed_expr)
704+
705+
visit(typed_expr)
706+
return unbound_vars
707+
708+
686709
##
687710
# Schedules
688711
##

python/egglog/egraph.py

Lines changed: 51 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from . import bindings
3131
from .conversion import *
32+
from .conversion import convert_to_same_type, resolve_literal
3233
from .declarations import *
3334
from .egraph_state import *
3435
from .ipython_magic import IN_IPYTHON
@@ -281,7 +282,6 @@ def function(
281282
mutates_first_arg: bool = ...,
282283
unextractable: bool = ...,
283284
ruleset: Ruleset | None = ...,
284-
use_body_as_name: bool = ...,
285285
subsume: bool = ...,
286286
) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...
287287

@@ -467,7 +467,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
467467
decls.set_function_decl(ref, decl)
468468
continue
469469
try:
470-
_, add_rewrite = _fn_decl(
470+
add_rewrite = _fn_decl(
471471
decls,
472472
egg_fn,
473473
ref,
@@ -505,19 +505,17 @@ class _FunctionConstructor:
505505
merge: Callable[[object, object], object] | None = None
506506
unextractable: bool = False
507507
ruleset: Ruleset | None = None
508-
use_body_as_name: bool = False
509508
subsume: bool = False
510509

511510
def __call__(self, fn: Callable) -> RuntimeFunction:
512511
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
513512

514513
def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
515514
decls = Declarations()
516-
ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
517-
ref, add_rewrite = _fn_decl(
515+
add_rewrite = _fn_decl(
518516
decls,
519517
self.egg_fn,
520-
ref,
518+
ref := FunctionRef(fn.__name__),
521519
fn,
522520
self.hint_locals,
523521
self.cost,
@@ -535,8 +533,7 @@ def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
535533
def _fn_decl(
536534
decls: Declarations,
537535
egg_name: str | None,
538-
# If ref is Callable, then generate the ref from the function name
539-
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
536+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
540537
fn: object,
541538
# Pass in the locals, retrieved from the frame when wrapping,
542539
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -549,7 +546,7 @@ def _fn_decl(
549546
ruleset: Ruleset | None = None,
550547
unextractable: bool = False,
551548
reverse_args: bool = False,
552-
) -> tuple[CallableRef, Callable[[], None]]:
549+
) -> Callable[[], None]:
553550
"""
554551
Sets the function decl for the function object and returns a thunk that adds the default rewrite for it.
555552
"""
@@ -619,50 +616,39 @@ def _fn_decl(
619616

620617
# defer this in generator so it doesn't resolve for builtins eagerly
621618
args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
622-
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
623-
res_thunk: Callable[[], object]
624-
# If we were not passed in a ref, this is an unnamed function, so eagerly compute the value and use that to refer to it
625-
if not ref:
626-
tuple_args = tuple(args)
627-
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
628-
assert isinstance(res, RuntimeExpr)
629-
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
630-
decls._unnamed_functions.add(res_ref)
631-
res_thunk = Thunk.value(res)
632619

620+
return_type_is_eqsort = (
621+
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
622+
)
623+
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
624+
signature_ = FunctionSignature(
625+
return_type=None if mutates_first_arg else return_type,
626+
var_arg_type=var_arg_type,
627+
arg_types=arg_types,
628+
arg_names=arg_names,
629+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
630+
reverse_args=reverse_args,
631+
)
632+
decl: ConstructorDecl | FunctionDecl
633+
if is_constructor:
634+
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
633635
else:
634-
return_type_is_eqsort = (
635-
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
636-
)
637-
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
638-
signature_ = FunctionSignature(
639-
return_type=None if mutates_first_arg else return_type,
640-
var_arg_type=var_arg_type,
641-
arg_types=arg_types,
642-
arg_names=arg_names,
643-
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
644-
reverse_args=reverse_args,
636+
if cost is not None:
637+
msg = "Cost can only be set for constructors"
638+
raise ValueError(msg)
639+
if unextractable:
640+
msg = "Unextractable can only be set for constructors"
641+
raise ValueError(msg)
642+
decl = FunctionDecl(
643+
signature=signature_,
644+
egg_name=egg_name,
645+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
646+
builtin=is_builtin,
645647
)
646-
decl: ConstructorDecl | FunctionDecl
647-
if is_constructor:
648-
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
649-
else:
650-
if cost is not None:
651-
msg = "Cost can only be set for constructors"
652-
raise ValueError(msg)
653-
if unextractable:
654-
msg = "Unextractable can only be set for constructors"
655-
raise ValueError(msg)
656-
decl = FunctionDecl(
657-
signature=signature_,
658-
egg_name=egg_name,
659-
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
660-
builtin=is_builtin,
661-
)
662-
res_ref = ref
663-
decls.set_function_decl(ref, decl)
664-
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
665-
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
648+
decls.set_function_decl(ref, decl)
649+
return Thunk.fn(
650+
_add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
651+
)
666652

667653

668654
# Overload to support arities 0-4 until variadic generics support map, so we can map from type to value
@@ -736,35 +722,24 @@ def _constant_thunk(
736722
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
737723

738724

739-
def _create_default_value(
725+
def _add_default_rewrite_function(
740726
decls: Declarations,
741-
ref: CallableRef | None,
727+
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
742728
fn: Callable,
743729
args: Iterable[TypedExprDecl],
744730
ruleset: Ruleset | None,
745-
) -> object:
731+
subsume: bool,
732+
res_type: TypeOrVarRef,
733+
) -> None:
746734
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
747735

748736
# If this is a classmethod, add the class as the first arg
749737
if isinstance(ref, ClassMethodRef):
750738
tp = decls.get_paramaterized_class(ref.class_name)
751739
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
752740
with set_current_ruleset(ruleset):
753-
return fn(*args)
754-
755-
756-
def _add_default_rewrite_function(
757-
decls: Declarations,
758-
ref: CallableRef,
759-
res_type: TypeOrVarRef,
760-
ruleset: Ruleset | None,
761-
value_thunk: Callable[[], object],
762-
subsume: bool,
763-
) -> None:
764-
"""
765-
Helper functions that resolves a value thunk to create the default value.
766-
"""
767-
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
741+
res = fn(*args)
742+
_add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)
768743

769744

770745
def _add_default_rewrite(
@@ -784,14 +759,21 @@ def _add_default_rewrite(
784759
return
785760
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
786761
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
762+
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
763+
ruleset_decls |= resolved_value
764+
765+
766+
def _add_default_rewrite_inner(
767+
decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
768+
) -> Declarations:
787769
if ruleset:
788770
ruleset_decls = ruleset._current_egg_decls
789771
ruleset_decl = ruleset.__egg_ruleset__
790772
else:
791773
ruleset_decls = decls
792774
ruleset_decl = decls.default_ruleset
793775
ruleset_decl.rules.append(rewrite_decl)
794-
ruleset_decls |= resolved_value
776+
return ruleset_decls
795777

796778

797779
def _last_param_variable(params: list[Parameter]) -> bool:

0 commit comments

Comments
 (0)