2929
3030from . import bindings
3131from .conversion import *
32+ from .conversion import convert_to_same_type , resolve_literal
3233from .declarations import *
3334from .egraph_state import *
3435from .ipython_magic import IN_IPYTHON
@@ -281,7 +282,6 @@ def function(
281282 mutates_first_arg : bool = ...,
282283 unextractable : bool = ...,
283284 ruleset : Ruleset | None = ...,
284- use_body_as_name : bool = ...,
285285 subsume : bool = ...,
286286) -> Callable [[CONSTRUCTOR_CALLABLE ], CONSTRUCTOR_CALLABLE ]: ...
287287
@@ -467,7 +467,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
467467 decls .set_function_decl (ref , decl )
468468 continue
469469 try :
470- _ , add_rewrite = _fn_decl (
470+ add_rewrite = _fn_decl (
471471 decls ,
472472 egg_fn ,
473473 ref ,
@@ -505,19 +505,17 @@ class _FunctionConstructor:
505505 merge : Callable [[object , object ], object ] | None = None
506506 unextractable : bool = False
507507 ruleset : Ruleset | None = None
508- use_body_as_name : bool = False
509508 subsume : bool = False
510509
511510 def __call__ (self , fn : Callable ) -> RuntimeFunction :
512511 return RuntimeFunction (* split_thunk (Thunk .fn (self .create_decls , fn )))
513512
514513 def create_decls (self , fn : Callable ) -> tuple [Declarations , CallableRef ]:
515514 decls = Declarations ()
516- ref = None if self .use_body_as_name else FunctionRef (fn .__name__ )
517- ref , add_rewrite = _fn_decl (
515+ add_rewrite = _fn_decl (
518516 decls ,
519517 self .egg_fn ,
520- ref ,
518+ ref := FunctionRef ( fn . __name__ ) ,
521519 fn ,
522520 self .hint_locals ,
523521 self .cost ,
@@ -535,8 +533,7 @@ def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
535533def _fn_decl (
536534 decls : Declarations ,
537535 egg_name : str | None ,
538- # If ref is Callable, then generate the ref from the function name
539- ref : FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None ,
536+ ref : FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef ,
540537 fn : object ,
541538 # Pass in the locals, retrieved from the frame when wrapping,
542539 # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -549,7 +546,7 @@ def _fn_decl(
549546 ruleset : Ruleset | None = None ,
550547 unextractable : bool = False ,
551548 reverse_args : bool = False ,
552- ) -> tuple [ CallableRef , Callable [[], None ] ]:
549+ ) -> Callable [[], None ]:
553550 """
554551 Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
555552 """
@@ -619,50 +616,39 @@ def _fn_decl(
619616
620617 # defer this in generator so it doesn't resolve for builtins eagerly
621618 args = (TypedExprDecl (tp .to_just (), UnboundVarDecl (name )) for name , tp in zip (arg_names , arg_types , strict = True ))
622- res_ref : FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
623- res_thunk : Callable [[], object ]
624- # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
625- if not ref :
626- tuple_args = tuple (args )
627- res = _create_default_value (decls , ref , fn , tuple_args , ruleset )
628- assert isinstance (res , RuntimeExpr )
629- res_ref = UnnamedFunctionRef (tuple_args , res .__egg_typed_expr__ )
630- decls ._unnamed_functions .add (res_ref )
631- res_thunk = Thunk .value (res )
632619
620+ return_type_is_eqsort = (
621+ not decls ._classes [return_type .name ].builtin if isinstance (return_type , TypeRefWithVars ) else False
622+ )
623+ is_constructor = not is_builtin and return_type_is_eqsort and merged is None
624+ signature_ = FunctionSignature (
625+ return_type = None if mutates_first_arg else return_type ,
626+ var_arg_type = var_arg_type ,
627+ arg_types = arg_types ,
628+ arg_names = arg_names ,
629+ arg_defaults = tuple (a .__egg_typed_expr__ .expr if a is not None else None for a in arg_defaults ),
630+ reverse_args = reverse_args ,
631+ )
632+ decl : ConstructorDecl | FunctionDecl
633+ if is_constructor :
634+ decl = ConstructorDecl (signature_ , egg_name , cost , unextractable )
633635 else :
634- return_type_is_eqsort = (
635- not decls . _classes [ return_type . name ]. builtin if isinstance ( return_type , TypeRefWithVars ) else False
636- )
637- is_constructor = not is_builtin and return_type_is_eqsort and merged is None
638- signature_ = FunctionSignature (
639- return_type = None if mutates_first_arg else return_type ,
640- var_arg_type = var_arg_type ,
641- arg_types = arg_types ,
642- arg_names = arg_names ,
643- arg_defaults = tuple ( a .__egg_typed_expr__ .expr if a is not None else None for a in arg_defaults ) ,
644- reverse_args = reverse_args ,
636+ if cost is not None :
637+ msg = "Cost can only be set for constructors"
638+ raise ValueError ( msg )
639+ if unextractable :
640+ msg = "Unextractable can only be set for constructors"
641+ raise ValueError ( msg )
642+ decl = FunctionDecl (
643+ signature = signature_ ,
644+ egg_name = egg_name ,
645+ merge = merged .__egg_typed_expr__ .expr if merged is not None else None ,
646+ builtin = is_builtin ,
645647 )
646- decl : ConstructorDecl | FunctionDecl
647- if is_constructor :
648- decl = ConstructorDecl (signature_ , egg_name , cost , unextractable )
649- else :
650- if cost is not None :
651- msg = "Cost can only be set for constructors"
652- raise ValueError (msg )
653- if unextractable :
654- msg = "Unextractable can only be set for constructors"
655- raise ValueError (msg )
656- decl = FunctionDecl (
657- signature = signature_ ,
658- egg_name = egg_name ,
659- merge = merged .__egg_typed_expr__ .expr if merged is not None else None ,
660- builtin = is_builtin ,
661- )
662- res_ref = ref
663- decls .set_function_decl (ref , decl )
664- res_thunk = Thunk .fn (_create_default_value , decls , ref , fn , args , ruleset , context = f"creating { ref } " )
665- return res_ref , Thunk .fn (_add_default_rewrite_function , decls , res_ref , return_type , ruleset , res_thunk , subsume )
648+ decls .set_function_decl (ref , decl )
649+ return Thunk .fn (
650+ _add_default_rewrite_function , decls , ref , fn , args , ruleset , subsume , return_type , context = f"creating { ref } "
651+ )
666652
667653
668654# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -736,35 +722,24 @@ def _constant_thunk(
736722 return decls , TypedExprDecl (type_ref .to_just (), CallDecl (callable_ref ))
737723
738724
739- def _create_default_value (
725+ def _add_default_rewrite_function (
740726 decls : Declarations ,
741- ref : CallableRef | None ,
727+ ref : FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef ,
742728 fn : Callable ,
743729 args : Iterable [TypedExprDecl ],
744730 ruleset : Ruleset | None ,
745- ) -> object :
731+ subsume : bool ,
732+ res_type : TypeOrVarRef ,
733+ ) -> None :
746734 args : list [object ] = [RuntimeExpr .__from_values__ (decls , a ) for a in args ]
747735
748736 # If this is a classmethod, add the class as the first arg
749737 if isinstance (ref , ClassMethodRef ):
750738 tp = decls .get_paramaterized_class (ref .class_name )
751739 args .insert (0 , RuntimeClass (Thunk .value (decls ), tp ))
752740 with set_current_ruleset (ruleset ):
753- return fn (* args )
754-
755-
756- def _add_default_rewrite_function (
757- decls : Declarations ,
758- ref : CallableRef ,
759- res_type : TypeOrVarRef ,
760- ruleset : Ruleset | None ,
761- value_thunk : Callable [[], object ],
762- subsume : bool ,
763- ) -> None :
764- """
765- Helper functions that resolves a value thunk to create the default value.
766- """
767- _add_default_rewrite (decls , ref , res_type , value_thunk (), ruleset , subsume )
741+ res = fn (* args )
742+ _add_default_rewrite (decls , ref , res_type , res , ruleset , subsume )
768743
769744
770745def _add_default_rewrite (
@@ -784,14 +759,21 @@ def _add_default_rewrite(
784759 return
785760 resolved_value = resolve_literal (type_ref , default_rewrite , Thunk .value (decls ))
786761 rewrite_decl = DefaultRewriteDecl (ref , resolved_value .__egg_typed_expr__ .expr , subsume )
762+ ruleset_decls = _add_default_rewrite_inner (decls , rewrite_decl , ruleset )
763+ ruleset_decls |= resolved_value
764+
765+
766+ def _add_default_rewrite_inner (
767+ decls : Declarations , rewrite_decl : DefaultRewriteDecl , ruleset : Ruleset | None
768+ ) -> Declarations :
787769 if ruleset :
788770 ruleset_decls = ruleset ._current_egg_decls
789771 ruleset_decl = ruleset .__egg_ruleset__
790772 else :
791773 ruleset_decls = decls
792774 ruleset_decl = decls .default_ruleset
793775 ruleset_decl .rules .append (rewrite_decl )
794- ruleset_decls |= resolved_value
776+ return ruleset_decls
795777
796778
797779def _last_param_variable (params : list [Parameter ]) -> bool :
0 commit comments