Skip to content

Commit 25f2312

Browse files
Move unnamed functions to separate structure
1 parent 30f0f0d commit 25f2312

File tree

4 files changed

+87
-75
lines changed

4 files changed

+87
-75
lines changed

python/egglog/declarations.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[D
120120

121121
@dataclass
122122
class Declarations:
123-
# TODO: Replace with set of unnamed function decls
124-
_functions: dict[str | UnnamedFunctionRef, FunctionDecl | RelationDecl] = field(default_factory=dict)
123+
_unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
124+
_functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
125125
_constants: dict[str, ConstantDecl] = field(default_factory=dict)
126126
_classes: dict[str, ClassDecl] = field(default_factory=dict)
127127
_rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
@@ -197,6 +197,8 @@ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
197197
init_fn = self._classes[class_name].init
198198
assert init_fn
199199
return init_fn
200+
case UnnamedFunctionRef():
201+
return ref.to_function_decl()
200202
assert_never(ref)
201203

202204
def set_function_decl(
@@ -329,20 +331,34 @@ class UnnamedFunctionRef:
329331
A reference to a function that doesn't have a name, but does have a body.
330332
"""
331333

332-
arg_types: tuple[JustTypeRef, ...]
333-
arg_names: tuple[str, ...]
334+
# tuple of var arg names and their types
335+
args: tuple[TypedExprDecl, ...]
334336
res: TypedExprDecl
335337

336-
@property
337-
def args(self) -> tuple[TypedExprDecl, ...]:
338-
return tuple(
339-
TypedExprDecl(tp, VarDecl(name, False)) for tp, name in zip(self.arg_types, self.arg_names, strict=True)
338+
def to_function_decl(self) -> FunctionDecl:
339+
arg_types = []
340+
arg_names = []
341+
for a in self.args:
342+
arg_types.append(a.tp.to_var())
343+
assert isinstance(a.expr, VarDecl)
344+
arg_names.append(a.expr.name)
345+
return FunctionDecl(
346+
FunctionSignature(
347+
arg_types=tuple(arg_types),
348+
arg_names=tuple(arg_names),
349+
arg_defaults=(None,) * len(self.args),
350+
return_type=self.res.tp.to_var(),
351+
),
340352
)
341353

354+
@property
355+
def egg_name(self) -> None | str:
356+
return None
357+
342358

343359
@dataclass(frozen=True)
344360
class FunctionRef:
345-
name: str | UnnamedFunctionRef
361+
name: str
346362

347363

348364
@dataclass(frozen=True)
@@ -380,7 +396,14 @@ class PropertyRef:
380396

381397

382398
CallableRef: TypeAlias = (
383-
FunctionRef | ConstantRef | MethodRef | ClassMethodRef | InitRef | ClassVariableRef | PropertyRef
399+
FunctionRef
400+
| ConstantRef
401+
| MethodRef
402+
| ClassMethodRef
403+
| InitRef
404+
| ClassVariableRef
405+
| PropertyRef
406+
| UnnamedFunctionRef
384407
)
385408

386409

python/egglog/egraph.py

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def _fn_decl(
694694
is_builtin: bool,
695695
ruleset: Ruleset | None = None,
696696
unextractable: bool = False,
697-
) -> tuple[FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef, Callable[[], None]]:
697+
) -> tuple[CallableRef, Callable[[], None]]:
698698
"""
699699
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
700700
"""
@@ -707,10 +707,7 @@ def _fn_decl(
707707
if "Callable" not in hint_globals:
708708
hint_globals["Callable"] = Callable
709709

710-
try:
711-
hints = get_type_hints(fn, hint_globals, hint_locals)
712-
except Exception as e:
713-
raise TypeError(f"Failed to get type hints for {fn}") from e
710+
hints = get_type_hints(fn, hint_globals, hint_locals)
714711

715712
params = list(signature(fn).parameters.values())
716713

@@ -771,33 +768,41 @@ def _fn_decl(
771768
)
772769
)
773770
decls.update(*merge_action)
774-
775-
signature_ = FunctionSignature(
776-
return_type=None if mutates_first_arg else return_type,
777-
var_arg_type=var_arg_type,
778-
arg_types=arg_types,
779-
arg_names=arg_names,
780-
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
781-
)
782-
decl = FunctionDecl(
783-
signature=signature_,
784-
cost=cost,
785-
egg_name=egg_name,
786-
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
787-
unextractable=unextractable,
788-
builtin=is_builtin,
789-
default=None if default is None else default.__egg_typed_expr__.expr,
790-
on_merge=tuple(a.action for a in merge_action),
791-
)
792-
res = Thunk.fn(_create_default_value, decls, ref, fn, signature_, ruleset)
771+
# defer this in generator so it doesnt resolve for builtins eagerly
772+
args = (TypedExprDecl(tp.to_just(), VarDecl(name, False)) for name, tp in zip(arg_names, arg_types, strict=True))
773+
res_ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
774+
res_thunk: Callable[[], object]
793775
# If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
794776
if not ref:
795-
res_value = res()
796-
assert isinstance(res_value, RuntimeExpr)
797-
just_arg_types = tuple(tp.to_just() for tp in arg_types)
798-
ref = FunctionRef(UnnamedFunctionRef(just_arg_types, arg_names, res_value.__egg_typed_expr__))
799-
decls.set_function_decl(ref, decl)
800-
return ref, Thunk.fn(_add_default_rewrite_function, decls, ref, signature_.semantic_return_type, ruleset, res)
777+
tuple_args = tuple(args)
778+
res = _create_default_value(decls, ref, fn, tuple_args, ruleset)
779+
assert isinstance(res, RuntimeExpr)
780+
res_ref = UnnamedFunctionRef(tuple_args, res.__egg_typed_expr__)
781+
decls._unnamed_functions.add(res_ref)
782+
res_thunk = Thunk.value(res)
783+
784+
else:
785+
signature_ = FunctionSignature(
786+
return_type=None if mutates_first_arg else return_type,
787+
var_arg_type=var_arg_type,
788+
arg_types=arg_types,
789+
arg_names=arg_names,
790+
arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
791+
)
792+
decl = FunctionDecl(
793+
signature=signature_,
794+
cost=cost,
795+
egg_name=egg_name,
796+
merge=merged.__egg_typed_expr__.expr if merged is not None else None,
797+
unextractable=unextractable,
798+
builtin=is_builtin,
799+
default=None if default is None else default.__egg_typed_expr__.expr,
800+
on_merge=tuple(a.action for a in merge_action),
801+
)
802+
res_ref = ref
803+
decls.set_function_decl(ref, decl)
804+
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
805+
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
801806

802807

803808
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -872,38 +877,25 @@ def _create_default_value(
872877
decls: Declarations,
873878
ref: CallableRef | None,
874879
fn: Callable,
875-
signature: FunctionSignature,
880+
args: Iterable[TypedExprDecl],
876881
ruleset: Ruleset | None,
877882
) -> object:
878-
args: list[object] = [
879-
RuntimeExpr.__from_values__(
880-
decls,
881-
TypedExprDecl(
882-
tp.to_just(),
883-
VarDecl(name, False),
884-
),
885-
)
886-
for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
887-
]
883+
args: list[object] = [RuntimeExpr.__from_values__(decls, a) for a in args]
888884

889885
# If this is a classmethod, add the class as the first arg
890886
if isinstance(ref, ClassMethodRef):
891887
tp = decls.get_paramaterized_class(ref.class_name)
892888
args.insert(0, RuntimeClass(Thunk.value(decls), tp))
893889
with set_current_ruleset(ruleset):
894-
try:
895-
return fn(*args)
896-
except Exception as err:
897-
msg = f"Error when calling {fn}"
898-
raise ValueError(msg) from err
890+
return fn(*args)
899891

900892

901893
def _add_default_rewrite_function(
902894
decls: Declarations,
903895
ref: CallableRef,
904896
res_type: TypeOrVarRef,
905897
ruleset: Ruleset | None,
906-
value_thunk: Thunk,
898+
value_thunk: Callable[[], object],
907899
) -> None:
908900
"""
909901
Helper functions that resolves a value thunk to create the default value.

python/egglog/egraph_state.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -401,16 +401,8 @@ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
401401
"""
402402
match ref:
403403
case FunctionRef(name):
404-
match name:
405-
case str(name):
406-
return name
407-
case UnnamedFunctionRef(arg_types, arg_names, val):
408-
parts = (
409-
list(arg_names)
410-
+ [self.type_ref_to_egg(tp) for tp in arg_types]
411-
+ [str(self.typed_expr_to_egg(val, False))]
412-
)
413-
return "_".join(parts)
404+
return name
405+
414406
case ConstantRef(name):
415407
return name
416408
case (
@@ -422,6 +414,11 @@ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
422414
return f"{cls_name}.{name}"
423415
case InitRef(cls_name):
424416
return f"{cls_name}.__init__"
417+
case UnnamedFunctionRef(args, val):
418+
parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
419+
str(self.typed_expr_to_egg(val, False))
420+
]
421+
return "_".join(parts)
425422
case _:
426423
assert_never(ref)
427424

python/egglog/pretty.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90
179179
self(de)
180180
case CallDecl(ref, exprs, _):
181181
match ref:
182-
case FunctionRef(UnnamedFunctionRef(_, _, res)):
182+
case FunctionRef(UnnamedFunctionRef(_, res)):
183183
self(res.expr)
184184
case _:
185185
for e in exprs:
@@ -427,6 +427,8 @@ def _call_inner( # noqa: PLR0911
427427
case InitRef(class_name):
428428
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
429429
return str(tp_ref), args
430+
case UnnamedFunctionRef():
431+
return ref, args
430432
assert_never(ref)
431433

432434
def _generate_name(self, typ: str) -> str:
@@ -453,9 +455,9 @@ def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl]) -> str:
453455
"""
454456
match ref:
455457
case FunctionRef(name):
456-
if not isinstance(name, str):
457-
return self._pretty_function_body(name, args)
458458
fn = name
459+
case UnnamedFunctionRef():
460+
return self._pretty_function_body(ref, args)
459461
case (
460462
ClassMethodRef(class_name, method_name)
461463
| MethodRef(class_name, method_name)
@@ -484,16 +486,14 @@ def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) ->
484486
"""
485487
Pretty print the body of a function, partially applying some arguments.
486488
"""
487-
var_args = [
488-
TypedExprDecl(tp, VarDecl(name, False)) for tp, name in zip(fn.arg_types, fn.arg_names, strict=True)
489-
]
489+
var_args = fn.args
490490
replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
491491
var_args = var_args[len(args) :]
492492
res = replace_typed_expr(fn.res, replacements)
493-
arg_names = fn.arg_names[len(args) :]
493+
arg_names = fn.args[len(args) :]
494494
prefix = "lambda"
495495
if arg_names:
496-
prefix += f" {', '.join(arg_names)}"
496+
prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
497497
return f"{prefix}: {self(res.expr)}"
498498

499499

0 commit comments

Comments
 (0)