29
29
30
30
from . import bindings
31
31
from .conversion import *
32
+ from .conversion import convert_to_same_type , resolve_literal
32
33
from .declarations import *
33
34
from .egraph_state import *
34
35
from .ipython_magic import IN_IPYTHON
@@ -281,7 +282,6 @@ def function(
281
282
mutates_first_arg : bool = ...,
282
283
unextractable : bool = ...,
283
284
ruleset : Ruleset | None = ...,
284
- use_body_as_name : bool = ...,
285
285
subsume : bool = ...,
286
286
) -> Callable [[CONSTRUCTOR_CALLABLE ], CONSTRUCTOR_CALLABLE ]: ...
287
287
@@ -467,7 +467,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
467
467
decls .set_function_decl (ref , decl )
468
468
continue
469
469
try :
470
- _ , add_rewrite = _fn_decl (
470
+ add_rewrite = _fn_decl (
471
471
decls ,
472
472
egg_fn ,
473
473
ref ,
@@ -505,19 +505,17 @@ class _FunctionConstructor:
505
505
merge : Callable [[object , object ], object ] | None = None
506
506
unextractable : bool = False
507
507
ruleset : Ruleset | None = None
508
- use_body_as_name : bool = False
509
508
subsume : bool = False
510
509
511
510
def __call__ (self , fn : Callable ) -> RuntimeFunction :
512
511
return RuntimeFunction (* split_thunk (Thunk .fn (self .create_decls , fn )))
513
512
514
513
def create_decls (self , fn : Callable ) -> tuple [Declarations , CallableRef ]:
515
514
decls = Declarations ()
516
- ref = None if self .use_body_as_name else FunctionRef (fn .__name__ )
517
- ref , add_rewrite = _fn_decl (
515
+ add_rewrite = _fn_decl (
518
516
decls ,
519
517
self .egg_fn ,
520
- ref ,
518
+ ref := FunctionRef ( fn . __name__ ) ,
521
519
fn ,
522
520
self .hint_locals ,
523
521
self .cost ,
@@ -535,8 +533,7 @@ def create_decls(self, fn: Callable) -> tuple[Declarations, CallableRef]:
535
533
def _fn_decl (
536
534
decls : Declarations ,
537
535
egg_name : str | None ,
538
- # If ref is Callable, then generate the ref from the function name
539
- ref : FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef | None ,
536
+ ref : FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef ,
540
537
fn : object ,
541
538
# Pass in the locals, retrieved from the frame when wrapping,
542
539
# so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -549,7 +546,7 @@ def _fn_decl(
549
546
ruleset : Ruleset | None = None ,
550
547
unextractable : bool = False ,
551
548
reverse_args : bool = False ,
552
- ) -> tuple [ CallableRef , Callable [[], None ] ]:
549
+ ) -> Callable [[], None ]:
553
550
"""
554
551
Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable.
555
552
"""
@@ -619,50 +616,39 @@ def _fn_decl(
619
616
620
617
# defer this in generator so it doesn't resolve for builtins eagerly
621
618
args = (TypedExprDecl (tp .to_just (), UnboundVarDecl (name )) for name , tp in zip (arg_names , arg_types , strict = True ))
622
- res_ref : FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef | UnnamedFunctionRef
623
- res_thunk : Callable [[], object ]
624
- # If we were not passed in a ref, this is an unnamed funciton, so eagerly compute the value and use that to refer to it
625
- if not ref :
626
- tuple_args = tuple (args )
627
- res = _create_default_value (decls , ref , fn , tuple_args , ruleset )
628
- assert isinstance (res , RuntimeExpr )
629
- res_ref = UnnamedFunctionRef (tuple_args , res .__egg_typed_expr__ )
630
- decls ._unnamed_functions .add (res_ref )
631
- res_thunk = Thunk .value (res )
632
619
620
+ return_type_is_eqsort = (
621
+ not decls ._classes [return_type .name ].builtin if isinstance (return_type , TypeRefWithVars ) else False
622
+ )
623
+ is_constructor = not is_builtin and return_type_is_eqsort and merged is None
624
+ signature_ = FunctionSignature (
625
+ return_type = None if mutates_first_arg else return_type ,
626
+ var_arg_type = var_arg_type ,
627
+ arg_types = arg_types ,
628
+ arg_names = arg_names ,
629
+ arg_defaults = tuple (a .__egg_typed_expr__ .expr if a is not None else None for a in arg_defaults ),
630
+ reverse_args = reverse_args ,
631
+ )
632
+ decl : ConstructorDecl | FunctionDecl
633
+ if is_constructor :
634
+ decl = ConstructorDecl (signature_ , egg_name , cost , unextractable )
633
635
else :
634
- return_type_is_eqsort = (
635
- not decls . _classes [ return_type . name ]. builtin if isinstance ( return_type , TypeRefWithVars ) else False
636
- )
637
- is_constructor = not is_builtin and return_type_is_eqsort and merged is None
638
- signature_ = FunctionSignature (
639
- return_type = None if mutates_first_arg else return_type ,
640
- var_arg_type = var_arg_type ,
641
- arg_types = arg_types ,
642
- arg_names = arg_names ,
643
- arg_defaults = tuple ( a .__egg_typed_expr__ .expr if a is not None else None for a in arg_defaults ) ,
644
- reverse_args = reverse_args ,
636
+ if cost is not None :
637
+ msg = "Cost can only be set for constructors"
638
+ raise ValueError ( msg )
639
+ if unextractable :
640
+ msg = "Unextractable can only be set for constructors"
641
+ raise ValueError ( msg )
642
+ decl = FunctionDecl (
643
+ signature = signature_ ,
644
+ egg_name = egg_name ,
645
+ merge = merged .__egg_typed_expr__ .expr if merged is not None else None ,
646
+ builtin = is_builtin ,
645
647
)
646
- decl : ConstructorDecl | FunctionDecl
647
- if is_constructor :
648
- decl = ConstructorDecl (signature_ , egg_name , cost , unextractable )
649
- else :
650
- if cost is not None :
651
- msg = "Cost can only be set for constructors"
652
- raise ValueError (msg )
653
- if unextractable :
654
- msg = "Unextractable can only be set for constructors"
655
- raise ValueError (msg )
656
- decl = FunctionDecl (
657
- signature = signature_ ,
658
- egg_name = egg_name ,
659
- merge = merged .__egg_typed_expr__ .expr if merged is not None else None ,
660
- builtin = is_builtin ,
661
- )
662
- res_ref = ref
663
- decls .set_function_decl (ref , decl )
664
- res_thunk = Thunk .fn (_create_default_value , decls , ref , fn , args , ruleset , context = f"creating { ref } " )
665
- return res_ref , Thunk .fn (_add_default_rewrite_function , decls , res_ref , return_type , ruleset , res_thunk , subsume )
648
+ decls .set_function_decl (ref , decl )
649
+ return Thunk .fn (
650
+ _add_default_rewrite_function , decls , ref , fn , args , ruleset , subsume , return_type , context = f"creating { ref } "
651
+ )
666
652
667
653
668
654
# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -736,35 +722,24 @@ def _constant_thunk(
736
722
return decls , TypedExprDecl (type_ref .to_just (), CallDecl (callable_ref ))
737
723
738
724
739
- def _create_default_value (
725
+ def _add_default_rewrite_function (
740
726
decls : Declarations ,
741
- ref : CallableRef | None ,
727
+ ref : FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef ,
742
728
fn : Callable ,
743
729
args : Iterable [TypedExprDecl ],
744
730
ruleset : Ruleset | None ,
745
- ) -> object :
731
+ subsume : bool ,
732
+ res_type : TypeOrVarRef ,
733
+ ) -> None :
746
734
args : list [object ] = [RuntimeExpr .__from_values__ (decls , a ) for a in args ]
747
735
748
736
# If this is a classmethod, add the class as the first arg
749
737
if isinstance (ref , ClassMethodRef ):
750
738
tp = decls .get_paramaterized_class (ref .class_name )
751
739
args .insert (0 , RuntimeClass (Thunk .value (decls ), tp ))
752
740
with set_current_ruleset (ruleset ):
753
- return fn (* args )
754
-
755
-
756
- def _add_default_rewrite_function (
757
- decls : Declarations ,
758
- ref : CallableRef ,
759
- res_type : TypeOrVarRef ,
760
- ruleset : Ruleset | None ,
761
- value_thunk : Callable [[], object ],
762
- subsume : bool ,
763
- ) -> None :
764
- """
765
- Helper functions that resolves a value thunk to create the default value.
766
- """
767
- _add_default_rewrite (decls , ref , res_type , value_thunk (), ruleset , subsume )
741
+ res = fn (* args )
742
+ _add_default_rewrite (decls , ref , res_type , res , ruleset , subsume )
768
743
769
744
770
745
def _add_default_rewrite (
@@ -784,14 +759,21 @@ def _add_default_rewrite(
784
759
return
785
760
resolved_value = resolve_literal (type_ref , default_rewrite , Thunk .value (decls ))
786
761
rewrite_decl = DefaultRewriteDecl (ref , resolved_value .__egg_typed_expr__ .expr , subsume )
762
+ ruleset_decls = _add_default_rewrite_inner (decls , rewrite_decl , ruleset )
763
+ ruleset_decls |= resolved_value
764
+
765
+
766
+ def _add_default_rewrite_inner (
767
+ decls : Declarations , rewrite_decl : DefaultRewriteDecl , ruleset : Ruleset | None
768
+ ) -> Declarations :
787
769
if ruleset :
788
770
ruleset_decls = ruleset ._current_egg_decls
789
771
ruleset_decl = ruleset .__egg_ruleset__
790
772
else :
791
773
ruleset_decls = decls
792
774
ruleset_decl = decls .default_ruleset
793
775
ruleset_decl .rules .append (rewrite_decl )
794
- ruleset_decls |= resolved_value
776
+ return ruleset_decls
795
777
796
778
797
779
def _last_param_variable (params : list [Parameter ]) -> bool :
0 commit comments