Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ _This project uses semantic versioning_
Also changes the representation to be an index into a list instead of the ID, making egglog programs more deterministic.
- Prefix constant declerations and unbound variables to not shadow let variables
- BREAKING: Remove `simplify` since it was removed upstream. You can manually replace it with an insert, run, then extract.
- Change how anonymous functions are converted to remove metaprogramming and lift only the unbound variables as args
- Add support for getting the "value" of a function type with `.eval()`, i.e. `assert UnstableFn(f).eval() == f`.

## 10.0.2 (2025-06-22)

Expand Down
101 changes: 73 additions & 28 deletions python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
from collections.abc import Callable
from fractions import Fraction
from functools import partial, reduce
from inspect import signature
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload

from typing_extensions import TypeVarTuple, Unpack

from .conversion import convert, converter, get_type_args
from egglog.declarations import TypedExprDecl

from .conversion import convert, converter, get_type_args, resolve_literal
from .declarations import *
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
from .functionalize import functionalize
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate
from .thunk import Thunk

if TYPE_CHECKING:
Expand Down Expand Up @@ -1022,46 +1024,89 @@ def __init__(self, f, *partial) -> None: ...
@method(egg_fn="unstable-app")
def __call__(self, *args: Unpack[TS]) -> T: ...

@method(preserve=True)
def eval(self) -> Callable[[Unpack[TS]], T]:
"""
If this is a constructor, returns either the callable directly or a `functools.partial` function if args are provided.
"""
assert isinstance(self, RuntimeExpr)
match self.__egg_typed_expr__.expr:
case PartialCallDecl(CallDecl() as call):
fn, args = _deconstruct_call_decl(self.__egg_decls_thunk__, call)
if not args:
return fn
return partial(fn, *args)
msg = "UnstableFn can only be evaluated if it is a function or a partial application of a function."
raise BuiltinEvalError(msg)


def _deconstruct_call_decl(
decls_thunk: Callable[[], Declarations], call: CallDecl
) -> tuple[Callable, tuple[object, ...]]:
"""
Deconstructs a CallDecl into a runtime callable and its arguments.
"""
args = call.args
arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
egg_bound = (
JustTypeRef(call.callable.class_name, call.bound_tp_params or ())
if isinstance(call.callable, (ClassMethodRef, InitRef, ClassVariableRef))
else None
)
if isinstance(call.callable, InitRef):
return RuntimeClass(
decls_thunk,
TypeRefWithVars(
call.callable.class_name,
),
), arg_exprs
return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs


# Method Type is for builtins like __getitem__
converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__))
converter(RuntimeFunction, UnstableFn, UnstableFn)
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))


def _convert_function(a: FunctionType) -> UnstableFn:
def _convert_function(fn: FunctionType) -> UnstableFn:
"""
Converts a function type to an unstable function
Converts a function type to an unstable function. This function will be an anon function in egglog.

Would just be UnstableFn(function(a)) but we have to look for any nonlocals and globals
which are runtime expressions with `var`s in them and add them as args to the function
Would just be UnstableFn(function(a)) but we have to account for unbound vars within the body.

This means that we have to turn all of those unbound vars into args to the function, and then
partially apply them, alongside creating a default rewrite for the function.
"""
# Update annotations of a to be the type we are trying to convert to
return_tp, *arg_tps = get_type_args()
a.__annotations__ = {
"return": return_tp,
# The first varnames should always be the arg names
**dict(zip(a.__code__.co_varnames, arg_tps, strict=False)),
}
# Modify name to make it unique
# a.__name__ = f"{a.__name__} {hash(a.__code__)}"
transformed_fn = functionalize(a, value_to_annotation)
assert isinstance(transformed_fn, partial)
return UnstableFn(
function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
*transformed_fn.args,
decls = Declarations()
return_type, *arg_types = [resolve_type_annotation_mutate(decls, tp) for tp in get_type_args()]
arg_names = [p.name for p in signature(fn).parameters.values()]
arg_decls = [
TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True)
]
res = resolve_literal(
return_type, fn(*(RuntimeExpr.__from_values__(decls, a) for a in arg_decls)), Thunk.value(decls)
)
res_expr = res.__egg_typed_expr__
decls |= res
# these are all the args that appear in the body that are not bound by the args of the function
unbound_vars = list(collect_unbound_vars(res_expr) - set(arg_decls))
# prefix the args with them
fn_ref = UnnamedFunctionRef(tuple(unbound_vars + arg_decls), res_expr)
rewrite_decl = DefaultRewriteDecl(fn_ref, res_expr.expr, subsume=True)
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, get_current_ruleset())
ruleset_decls |= res


def value_to_annotation(a: object) -> type | None:
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
if not isinstance(a, RuntimeExpr):
return None
return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
fn = RuntimeFunction(Thunk.value(decls), Thunk.value(fn_ref))
return UnstableFn(fn, *(RuntimeExpr.__from_values__(decls, v) for v in unbound_vars))


converter(FunctionType, UnstableFn, _convert_function)

##
# Utility Functions
##


def _extract_lit(e: BaseExpr) -> LitType:
"""
Expand Down
23 changes: 23 additions & 0 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"UnboundVarDecl",
"UnionDecl",
"UnnamedFunctionRef",
"collect_unbound_vars",
"replace_typed_expr",
"upcast_declerations",
]
Expand Down Expand Up @@ -683,6 +684,28 @@ def _inner(typed_expr: TypedExprDecl) -> TypedExprDecl:
return _inner(typed_expr)


def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
"""
Returns the set of all unbound vars
"""
seen = set[TypedExprDecl]()
unbound_vars = set[TypedExprDecl]()

def visit(typed_expr: TypedExprDecl) -> None:
if typed_expr in seen:
return
seen.add(typed_expr)
match typed_expr.expr:
case CallDecl(_, args) | PartialCallDecl(CallDecl(_, args)):
for arg in args:
visit(arg)
case UnboundVarDecl(_):
unbound_vars.add(typed_expr)

visit(typed_expr)
return unbound_vars


##
# Schedules
##
Expand Down
120 changes: 51 additions & 69 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from . import bindings
from .conversion import *
from .conversion import convert_to_same_type, resolve_literal
from .declarations import *
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
Expand Down Expand Up @@ -281,7 +282,6 @@ def function(
mutates_first_arg: bool = ...,
unextractable: bool = ...,
ruleset: Ruleset | None = ...,
use_body_as_name: bool = ...,
subsume: bool = ...,
) -> Callable[[CONSTRUCTOR_CALLABLE], CONSTRUCTOR_CALLABLE]: ...

Expand Down Expand Up @@ -467,7 +467,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
decls.set_function_decl(ref, decl)
continue
try:
_, add_rewrite = _fn_decl(
add_rewrite = _fn_decl(
decls,
egg_fn,
ref,
Expand Down Expand Up @@ -505,19 +505,17 @@ class _FunctionConstructor:
merge: Callable[[object, object], object] | None = None
unextractable: bool = False
ruleset: Ruleset | None = None
use_body_as_name: bool = False
subsume: bool = False

def __call__(self, fn: Callable) -> RuntimeFunction:
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))

def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
decls = Declarations()
ref = None if self.use_body_as_name else FunctionRef(fn.__name__)
ref, add_rewrite = _fn_decl(
add_rewrite = _fn_decl(
decls,
self.egg_fn,
ref,
ref := FunctionRef(fn.__name__),
fn,
self.hint_locals,
self.cost,
Expand All @@ -535,8 +533,7 @@ def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
def _fn_decl(
decls: Declarations,
egg_name: str | None,
# If ref is Callable, then generate the ref from the function name
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None,
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
fn: object,
# Pass in the locals, retrieved from the frame when wrapping,
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
Expand All @@ -549,7 +546,7 @@ def _fn_decl(
ruleset: Ruleset | None = None,
unextractable: bool = False,
reverse_args: bool = False,
) -> tuple[CallableRef, Callable[[], None]]:
) -> Callable[[], None]:
"""
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
"""
Expand Down Expand Up @@ -619,50 +616,39 @@ def _fn_decl(

# defer this in generator so it doesn't resolve for builtins eagerly
args = (TypedExprDecl(tp.to_just(), UnboundVarDecl(name)) for name, tp in zip(arg_names, arg_types, strict=True))
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
res_thunk: Callable[[], object]
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
if not ref:
tuple_args = tuple(args)
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
assert isinstance(res, RuntimeExpr)
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
decls._unnamed_functions.add(res_ref)
res_thunk = Thunk.value(res)

return_type_is_eqsort = (
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
)
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
signature_ = FunctionSignature(
return_type=None if mutates_first_arg else return_type,
var_arg_type=var_arg_type,
arg_types=arg_types,
arg_names=arg_names,
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
reverse_args=reverse_args,
)
decl: ConstructorDecl | FunctionDecl
if is_constructor:
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
else:
return_type_is_eqsort = (
not decls._classes[return_type.name].builtin if isinstance(return_type, TypeRefWithVars) else False
)
is_constructor = not is_builtin and return_type_is_eqsort and merged is None
signature_ = FunctionSignature(
return_type=None if mutates_first_arg else return_type,
var_arg_type=var_arg_type,
arg_types=arg_types,
arg_names=arg_names,
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
reverse_args=reverse_args,
if cost is not None:
msg = "Cost can only be set for constructors"
raise ValueError(msg)
if unextractable:
msg = "Unextractable can only be set for constructors"
raise ValueError(msg)
decl = FunctionDecl(
signature=signature_,
egg_name=egg_name,
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
builtin=is_builtin,
)
decl: ConstructorDecl | FunctionDecl
if is_constructor:
decl = ConstructorDecl(signature_, egg_name, cost, unextractable)
else:
if cost is not None:
msg = "Cost can only be set for constructors"
raise ValueError(msg)
if unextractable:
msg = "Unextractable can only be set for constructors"
raise ValueError(msg)
decl = FunctionDecl(
signature=signature_,
egg_name=egg_name,
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
builtin=is_builtin,
)
res_ref = ref
decls.set_function_decl(ref, decl)
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
decls.set_function_decl(ref, decl)
return Thunk.fn(
_add_default_rewrite_function, decls, ref, fn, args, ruleset, subsume, return_type, context=f"creating {ref}"
)


# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
Expand Down Expand Up @@ -736,35 +722,24 @@ def _constant_thunk(
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))


def _create_default_value(
def _add_default_rewrite_function(
decls: Declarations,
ref: CallableRef | None,
ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
fn: Callable,
args: Iterable[TypedExprDecl],
ruleset: Ruleset | None,
) -> object:
subsume: bool,
res_type: TypeOrVarRef,
) -> None:
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]

# If this is a classmethod, add the class as the first arg
if isinstance(ref, ClassMethodRef):
tp = decls.get_paramaterized_class(ref.class_name)
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
with set_current_ruleset(ruleset):
return fn(*args)


def _add_default_rewrite_function(
decls: Declarations,
ref: CallableRef,
res_type: TypeOrVarRef,
ruleset: Ruleset | None,
value_thunk: Callable[[], object],
subsume: bool,
) -> None:
"""
Helper functions that resolves a value thunk to create the default value.
"""
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)
res = fn(*args)
_add_default_rewrite(decls, ref, res_type, res, ruleset, subsume)


def _add_default_rewrite(
Expand All @@ -784,14 +759,21 @@ def _add_default_rewrite(
return
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
ruleset_decls |= resolved_value


def _add_default_rewrite_inner(
decls: Declarations, rewrite_decl: DefaultRewriteDecl, ruleset: Ruleset | None
) -> Declarations:
if ruleset:
ruleset_decls = ruleset._current_egg_decls
ruleset_decl = ruleset.__egg_ruleset__
else:
ruleset_decls = decls
ruleset_decl = decls.default_ruleset
ruleset_decl.rules.append(rewrite_decl)
ruleset_decls |= resolved_value
return ruleset_decls


def _last_param_variable(params: list[Parameter]) -> bool:
Expand Down
Loading
Loading