@@ -269,7 +269,7 @@ def method(
269269 unextractable : bool = False ,
270270 ) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]:
271271 return lambda fn : _WrappedMethod (
272- egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable
272+ egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable , False
273273 )
274274
275275 @overload
@@ -404,6 +404,7 @@ def method(
404404 on_merge : Callable [[Any , Any ], Iterable [ActionLike ]] | None = None ,
405405 mutates_self : bool = False ,
406406 unextractable : bool = False ,
407+ subsume : bool = False ,
407408) -> Callable [[CALLABLE ], CALLABLE ]: ...
408409
409410
@@ -417,6 +418,7 @@ def method(
417418 on_merge : Callable [[EXPR , EXPR ], Iterable [ActionLike ]] | None = None ,
418419 mutates_self : bool = False ,
419420 unextractable : bool = False ,
421+ subsume : bool = False ,
420422) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]: ...
421423
422424
@@ -430,11 +432,14 @@ def method(
430432 preserve : bool = False ,
431433 mutates_self : bool = False ,
432434 unextractable : bool = False ,
435+ subsume : bool = False ,
433436) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]:
434437 """
435438 Any method can be decorated with this to customize it's behavior. This is only supported in classes which subclass :class:`Expr`.
436439 """
437- return lambda fn : _WrappedMethod (egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable )
440+ return lambda fn : _WrappedMethod (
441+ egg_fn , cost , default , merge , on_merge , fn , preserve , mutates_self , unextractable , subsume
442+ )
438443
439444
440445class _ExprMetaclass (type ):
@@ -519,7 +524,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
519524 (inner_tp ,) = v .__args__
520525 type_ref = resolve_type_annotation (decls , inner_tp )
521526 cls_decl .class_variables [k ] = ConstantDecl (type_ref .to_just ())
522- _add_default_rewrite (decls , ClassVariableRef (cls_name , k ), type_ref , namespace .pop (k , None ), ruleset )
527+ _add_default_rewrite (
528+ decls , ClassVariableRef (cls_name , k ), type_ref , namespace .pop (k , None ), ruleset , subsume = False
529+ )
523530 else :
524531 msg = f"On class { cls_name } , for attribute '{ k } ', expected a ClassVar, but got { v } "
525532 raise NotImplementedError (msg )
@@ -542,12 +549,12 @@ def _generate_class_decls( # noqa: C901,PLR0912
542549 if is_init and cls_name in LIT_CLASS_NAMES :
543550 continue
544551 match method :
545- case _WrappedMethod (egg_fn , cost , default , merge , on_merge , fn , preserve , mutates , unextractable ):
552+ case _WrappedMethod (egg_fn , cost , default , merge , on_merge , fn , preserve , mutates , unextractable , subsume ):
546553 pass
547554 case _:
548555 egg_fn , cost , default , merge , on_merge = None , None , None , None , None
549556 fn = method
550- unextractable , preserve = False , False
557+ unextractable , preserve , subsume = False , False , False
551558 mutates = method_name in ALWAYS_MUTATES_SELF
552559 if preserve :
553560 cls_decl .preserved_methods [method_name ] = fn
@@ -572,7 +579,20 @@ def _generate_class_decls( # noqa: C901,PLR0912
572579 continue
573580
574581 _ , add_rewrite = _fn_decl (
575- decls , egg_fn , ref , fn , locals , default , cost , merge , on_merge , mutates , builtin , ruleset , unextractable
582+ decls ,
583+ egg_fn ,
584+ ref ,
585+ fn ,
586+ locals ,
587+ default ,
588+ cost ,
589+ merge ,
590+ on_merge ,
591+ mutates ,
592+ builtin ,
593+ ruleset = ruleset ,
594+ unextractable = unextractable ,
595+ subsume = subsume ,
576596 )
577597
578598 if not builtin and not isinstance (ref , InitRef ) and not mutates :
@@ -602,6 +622,7 @@ def function(
602622 builtin : bool = False ,
603623 ruleset : Ruleset | None = None ,
604624 use_body_as_name : bool = False ,
625+ subsume : bool = False ,
605626) -> Callable [[CALLABLE ], CALLABLE ]: ...
606627
607628
@@ -617,6 +638,7 @@ def function(
617638 unextractable : bool = False ,
618639 ruleset : Ruleset | None = None ,
619640 use_body_as_name : bool = False ,
641+ subsume : bool = False ,
620642) -> Callable [[Callable [P , EXPR ]], Callable [P , EXPR ]]: ...
621643
622644
@@ -649,6 +671,7 @@ class _FunctionConstructor:
649671 unextractable : bool = False
650672 ruleset : Ruleset | None = None
651673 use_body_as_name : bool = False
674+ subsume : bool = False
652675
653676 def __call__ (self , fn : Callable [..., RuntimeExpr ]) -> RuntimeFunction :
654677 return RuntimeFunction (* split_thunk (Thunk .fn (self .create_decls , fn )))
@@ -668,7 +691,8 @@ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, Ca
668691 self .on_merge ,
669692 self .mutates_first_arg ,
670693 self .builtin ,
671- self .ruleset ,
694+ ruleset = self .ruleset ,
695+ subsume = self .subsume ,
672696 unextractable = self .unextractable ,
673697 )
674698 add_rewrite ()
@@ -690,6 +714,7 @@ def _fn_decl(
690714 on_merge : Callable [[RuntimeExpr , RuntimeExpr ], Iterable [ActionLike ]] | None ,
691715 mutates_first_arg : bool ,
692716 is_builtin : bool ,
717+ subsume : bool ,
693718 ruleset : Ruleset | None = None ,
694719 unextractable : bool = False ,
695720) -> tuple [CallableRef , Callable [[], None ]]:
@@ -804,7 +829,7 @@ def _fn_decl(
804829 res_ref = ref
805830 decls .set_function_decl (ref , decl )
806831 res_thunk = Thunk .fn (_create_default_value , decls , ref , fn , args , ruleset )
807- return res_ref , Thunk .fn (_add_default_rewrite_function , decls , res_ref , return_type , ruleset , res_thunk )
832+ return res_ref , Thunk .fn (_add_default_rewrite_function , decls , res_ref , return_type , ruleset , res_thunk , subsume )
808833
809834
810835# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -871,7 +896,7 @@ def _constant_thunk(
871896 type_ref = resolve_type_annotation (decls , tp )
872897 callable_ref = ConstantRef (name )
873898 decls ._constants [name ] = ConstantDecl (type_ref .to_just (), egg_name )
874- _add_default_rewrite (decls , callable_ref , type_ref , default_replacement , ruleset )
899+ _add_default_rewrite (decls , callable_ref , type_ref , default_replacement , ruleset , subsume = False )
875900 return decls , TypedExprDecl (type_ref .to_just (), CallDecl (callable_ref ))
876901
877902
@@ -898,15 +923,21 @@ def _add_default_rewrite_function(
898923 res_type : TypeOrVarRef ,
899924 ruleset : Ruleset | None ,
900925 value_thunk : Callable [[], object ],
926+ subsume : bool ,
901927) -> None :
902928 """
903929 Helper functions that resolves a value thunk to create the default value.
904930 """
905- _add_default_rewrite (decls , ref , res_type , value_thunk (), ruleset )
931+ _add_default_rewrite (decls , ref , res_type , value_thunk (), ruleset , subsume )
906932
907933
908934def _add_default_rewrite (
909- decls : Declarations , ref : CallableRef , type_ref : TypeOrVarRef , default_rewrite : object , ruleset : Ruleset | None
935+ decls : Declarations ,
936+ ref : CallableRef ,
937+ type_ref : TypeOrVarRef ,
938+ default_rewrite : object ,
939+ ruleset : Ruleset | None ,
940+ subsume : bool ,
910941) -> None :
911942 """
912943 Adds a default rewrite for the callable, if the default rewrite is not None
@@ -916,7 +947,7 @@ def _add_default_rewrite(
916947 if default_rewrite is None :
917948 return
918949 resolved_value = resolve_literal (type_ref , default_rewrite , Thunk .value (decls ))
919- rewrite_decl = DefaultRewriteDecl (ref , resolved_value .__egg_typed_expr__ .expr )
950+ rewrite_decl = DefaultRewriteDecl (ref , resolved_value .__egg_typed_expr__ .expr , subsume )
920951 if ruleset :
921952 ruleset_decls = ruleset ._current_egg_decls
922953 ruleset_decl = ruleset .__egg_ruleset__
@@ -1341,8 +1372,6 @@ def saturate(
13411372 from .visualizer_widget import VisualizerWidget
13421373
13431374 def to_json () -> str :
1344- if expr :
1345- print (self .extract (expr ))
13461375 return self ._serialize (** kwargs ).to_json ()
13471376
13481377 egraphs = [to_json ()]
@@ -1407,6 +1436,7 @@ class _WrappedMethod(Generic[P, EXPR]):
14071436 preserve : bool
14081437 mutates_self : bool
14091438 unextractable : bool
1439+ subsume : bool
14101440
14111441 def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> EXPR :
14121442 msg = "We should never call a wrapped method. Did you forget to wrap the class?"
0 commit comments