2626
2727
2828__all__ = [
29+ "BigInt" ,
30+ "BigIntLike" ,
31+ "BigRat" ,
32+ "BigRatLike" ,
2933 "Bool" ,
3034 "BoolLike" ,
3135 "Map" ,
3236 "MapLike" ,
37+ "MultiSet" ,
3338 "PyObject" ,
3439 "Rational" ,
3540 "Set" ,
@@ -329,14 +334,14 @@ class Map(BuiltinExpr, Generic[T, V]):
329334 @method (preserve = True )
330335 def eval (self ) -> dict [T , V ]:
331336 call = _extract_call (self )
332- expr = cast (RuntimeExpr , self )
337+ expr = cast (" RuntimeExpr" , self )
333338 d = {}
334339 while call .callable != ClassMethodRef ("Map" , "empty" ):
335340 assert call .callable == MethodRef ("Map" , "insert" )
336341 call_typed , k_typed , v_typed = call .args
337342 assert isinstance (call_typed .expr , CallDecl )
338- k = cast (T , expr .__with_expr__ (k_typed ))
339- v = cast (V , expr .__with_expr__ (v_typed ))
343+ k = cast ("T" , expr .__with_expr__ (k_typed ))
344+ v = cast ("V" , expr .__with_expr__ (v_typed ))
340345 d [k ] = v
341346 call = call_typed .expr
342347 return d
@@ -397,7 +402,7 @@ class Set(BuiltinExpr, Generic[T]):
397402 def eval (self ) -> set [T ]:
398403 call = _extract_call (self )
399404 assert call .callable == InitRef ("Set" )
400- return {cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args }
405+ return {cast ("T" , cast (" RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args }
401406
402407 @method (preserve = True )
403408 def __iter__ (self ) -> Iterator [T ]:
@@ -454,6 +459,53 @@ def rebuild(self) -> Set[T]: ...
454459SetLike : TypeAlias = Set [T ] | set [TO ]
455460
456461
462+ class MultiSet (BuiltinExpr , Generic [T ]):
463+ @method (preserve = True )
464+ def eval (self ) -> list [T ]:
465+ call = _extract_call (self )
466+ assert call .callable == InitRef ("MultiSet" )
467+ return [cast ("T" , cast ("RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args ]
468+
469+ @method (preserve = True )
470+ def __iter__ (self ) -> Iterator [T ]:
471+ return iter (self .eval ())
472+
473+ @method (preserve = True )
474+ def __len__ (self ) -> int :
475+ return len (self .eval ())
476+
477+ @method (preserve = True )
478+ def __contains__ (self , key : T ) -> bool :
479+ return key in self .eval ()
480+
481+ @method (egg_fn = "multiset-of" )
482+ def __init__ (self , * args : T ) -> None : ...
483+
484+ @method (egg_fn = "multiset-insert" )
485+ def insert (self , value : T ) -> MultiSet [T ]: ...
486+
487+ @method (egg_fn = "multiset-not-contains" )
488+ def not_contains (self , value : T ) -> Unit : ...
489+
490+ @method (egg_fn = "multiset-contains" )
491+ def contains (self , value : T ) -> Unit : ...
492+
493+ @method (egg_fn = "multiset-remove" )
494+ def remove (self , value : T ) -> MultiSet [T ]: ...
495+
496+ @method (egg_fn = "multiset-length" )
497+ def length (self ) -> i64 : ...
498+
499+ @method (egg_fn = "multiset-pick" )
500+ def pick (self ) -> T : ...
501+
502+ @method (egg_fn = "multiset-sum" )
503+ def __add__ (self , other : MultiSet [T ]) -> MultiSet [T ]: ...
504+
505+ @method (egg_fn = "unstable-multiset-map" , reverse_args = True )
506+ def map (self , f : Callable [[T ], T ]) -> MultiSet [T ]: ...
507+
508+
457509class Rational (BuiltinExpr ):
458510 @method (preserve = True )
459511 def eval (self ) -> Fraction :
@@ -537,14 +589,237 @@ def numer(self) -> i64: ...
537589 def denom (self ) -> i64 : ...
538590
539591
592+ class BigInt (BuiltinExpr ):
593+ @method (preserve = True )
594+ def eval (self ) -> int :
595+ call = _extract_call (self )
596+ assert call .callable == ClassMethodRef ("BigInt" , "from_string" )
597+ (s ,) = call .args
598+ assert isinstance (s .expr , LitDecl )
599+ assert isinstance (s .expr .value , str )
600+ return int (s .expr .value )
601+
602+ @method (preserve = True )
603+ def __index__ (self ) -> int :
604+ return self .eval ()
605+
606+ @method (preserve = True )
607+ def __int__ (self ) -> int :
608+ return self .eval ()
609+
610+ @method (egg_fn = "from-string" )
611+ @classmethod
612+ def from_string (cls , s : StringLike ) -> BigInt : ...
613+
614+ @method (egg_fn = "bigint" )
615+ def __init__ (self , value : i64Like ) -> None : ...
616+
617+ @method (egg_fn = "+" )
618+ def __add__ (self , other : BigIntLike ) -> BigInt : ...
619+
620+ @method (egg_fn = "-" )
621+ def __sub__ (self , other : BigIntLike ) -> BigInt : ...
622+
623+ @method (egg_fn = "*" )
624+ def __mul__ (self , other : BigIntLike ) -> BigInt : ...
625+
626+ @method (egg_fn = "/" )
627+ def __truediv__ (self , other : BigIntLike ) -> BigInt : ...
628+
629+ @method (egg_fn = "%" )
630+ def __mod__ (self , other : BigIntLike ) -> BigInt : ...
631+
632+ @method (egg_fn = "&" )
633+ def __and__ (self , other : BigIntLike ) -> BigInt : ...
634+
635+ @method (egg_fn = "|" )
636+ def __or__ (self , other : BigIntLike ) -> BigInt : ...
637+
638+ @method (egg_fn = "^" )
639+ def __xor__ (self , other : BigIntLike ) -> BigInt : ...
640+
641+ @method (egg_fn = "<<" )
642+ def __lshift__ (self , other : i64Like ) -> BigInt : ...
643+
644+ @method (egg_fn = ">>" )
645+ def __rshift__ (self , other : i64Like ) -> BigInt : ...
646+
647+ def __radd__ (self , other : BigIntLike ) -> BigInt : ...
648+
649+ def __rsub__ (self , other : BigIntLike ) -> BigInt : ...
650+
651+ def __rmul__ (self , other : BigIntLike ) -> BigInt : ...
652+
653+ def __rtruediv__ (self , other : BigIntLike ) -> BigInt : ...
654+
655+ def __rmod__ (self , other : BigIntLike ) -> BigInt : ...
656+
657+ def __rand__ (self , other : BigIntLike ) -> BigInt : ...
658+
659+ def __ror__ (self , other : BigIntLike ) -> BigInt : ...
660+
661+ def __rxor__ (self , other : BigIntLike ) -> BigInt : ...
662+
663+ @method (egg_fn = "not-Z" )
664+ def __invert__ (self ) -> BigInt : ...
665+
666+ @method (egg_fn = "bits" )
667+ def bits (self ) -> BigInt : ...
668+
669+ @method (egg_fn = "<" )
670+ def __lt__ (self , other : BigIntLike ) -> Unit : # type: ignore[empty-body,has-type]
671+ ...
672+
673+ @method (egg_fn = ">" )
674+ def __gt__ (self , other : BigIntLike ) -> Unit : ...
675+
676+ @method (egg_fn = "<=" )
677+ def __le__ (self , other : BigIntLike ) -> Unit : # type: ignore[empty-body,has-type]
678+ ...
679+
680+ @method (egg_fn = ">=" )
681+ def __ge__ (self , other : BigIntLike ) -> Unit : ...
682+
683+ @method (egg_fn = "min" )
684+ def min (self , other : BigIntLike ) -> BigInt : ...
685+
686+ @method (egg_fn = "max" )
687+ def max (self , other : BigIntLike ) -> BigInt : ...
688+
689+ @method (egg_fn = "to-string" )
690+ def to_string (self ) -> String : ...
691+
692+ @method (egg_fn = "bool-=" )
693+ def bool_eq (self , other : BigIntLike ) -> Bool : ...
694+
695+ @method (egg_fn = "bool-<" )
696+ def bool_lt (self , other : BigIntLike ) -> Bool : ...
697+
698+ @method (egg_fn = "bool->" )
699+ def bool_gt (self , other : BigIntLike ) -> Bool : ...
700+
701+ @method (egg_fn = "bool-<=" )
702+ def bool_le (self , other : BigIntLike ) -> Bool : ...
703+
704+ @method (egg_fn = "bool->=" )
705+ def bool_ge (self , other : BigIntLike ) -> Bool : ...
706+
707+
708+ converter (i64 , BigInt , lambda i : BigInt (i ))
709+
710+ BigIntLike : TypeAlias = BigInt | i64Like
711+
712+
713+ class BigRat (BuiltinExpr ):
714+ @method (preserve = True )
715+ def eval (self ) -> Fraction :
716+ call = _extract_call (self )
717+ assert call .callable == InitRef ("BigRat" )
718+
719+ def _to_fraction (e : TypedExprDecl ) -> Fraction :
720+ expr = e .expr
721+ assert isinstance (expr , CallDecl )
722+ assert expr .callable == ClassMethodRef ("BigInt" , "from_string" )
723+ (s ,) = expr .args
724+ assert isinstance (s .expr , LitDecl )
725+ assert isinstance (s .expr .value , str )
726+ return Fraction (s .expr .value )
727+
728+ num , den = call .args
729+ return Fraction (_to_fraction (num ), _to_fraction (den ))
730+
731+ @method (preserve = True )
732+ def __float__ (self ) -> float :
733+ return float (self .eval ())
734+
735+ @method (preserve = True )
736+ def __int__ (self ) -> int :
737+ return int (self .eval ())
738+
739+ @method (egg_fn = "bigrat" )
740+ def __init__ (self , num : BigIntLike , den : BigIntLike ) -> None : ...
741+
742+ @method (egg_fn = "to-f64" )
743+ def to_f64 (self ) -> f64 : ...
744+
745+ @method (egg_fn = "+" )
746+ def __add__ (self , other : BigRatLike ) -> BigRat : ...
747+
748+ @method (egg_fn = "-" )
749+ def __sub__ (self , other : BigRatLike ) -> BigRat : ...
750+
751+ @method (egg_fn = "*" )
752+ def __mul__ (self , other : BigRatLike ) -> BigRat : ...
753+
754+ @method (egg_fn = "/" )
755+ def __truediv__ (self , other : BigRatLike ) -> BigRat : ...
756+
757+ @method (egg_fn = "min" )
758+ def min (self , other : BigRatLike ) -> BigRat : ...
759+
760+ @method (egg_fn = "max" )
761+ def max (self , other : BigRatLike ) -> BigRat : ...
762+
763+ @method (egg_fn = "neg" )
764+ def __neg__ (self ) -> BigRat : ...
765+
766+ @method (egg_fn = "abs" )
767+ def __abs__ (self ) -> BigRat : ...
768+
769+ @method (egg_fn = "floor" )
770+ def floor (self ) -> BigRat : ...
771+
772+ @method (egg_fn = "ceil" )
773+ def ceil (self ) -> BigRat : ...
774+
775+ @method (egg_fn = "round" )
776+ def round (self ) -> BigRat : ...
777+
778+ @method (egg_fn = "pow" )
779+ def __pow__ (self , other : BigRatLike ) -> BigRat : ...
780+
781+ @method (egg_fn = "log" )
782+ def log (self ) -> BigRat : ...
783+
784+ @method (egg_fn = "sqrt" )
785+ def sqrt (self ) -> BigRat : ...
786+
787+ @method (egg_fn = "cbrt" )
788+ def cbrt (self ) -> BigRat : ...
789+
790+ @method (egg_fn = "numer" ) # type: ignore[misc]
791+ @property
792+ def numer (self ) -> BigInt : ...
793+
794+ @method (egg_fn = "denom" ) # type: ignore[misc]
795+ @property
796+ def denom (self ) -> BigInt : ...
797+
798+ @method (egg_fn = "<" )
799+ def __lt__ (self , other : BigRatLike ) -> Unit : ... # type: ignore[has-type]
800+
801+ @method (egg_fn = ">" )
802+ def __gt__ (self , other : BigRatLike ) -> Unit : ...
803+
804+ @method (egg_fn = ">=" )
805+ def __ge__ (self , other : BigRatLike ) -> Unit : ... # type: ignore[has-type]
806+
807+ @method (egg_fn = "<=" )
808+ def __le__ (self , other : BigRatLike ) -> Unit : ...
809+
810+
811+ converter (Fraction , BigRat , lambda f : BigRat (f .numerator , f .denominator ))
812+ BigRatLike : TypeAlias = BigRat | Fraction
813+
814+
540815class Vec (BuiltinExpr , Generic [T ]):
541816 @method (preserve = True )
542817 def eval (self ) -> tuple [T , ...]:
543818 call = _extract_call (self )
544819 if call .callable == ClassMethodRef ("Vec" , "empty" ):
545820 return ()
546821 assert call .callable == InitRef ("Vec" )
547- return tuple (cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args )
822+ return tuple (cast ("T" , cast (" RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args )
548823
549824 @method (preserve = True )
550825 def __iter__ (self ) -> Iterator [T ]:
@@ -611,7 +886,7 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
611886class PyObject (BuiltinExpr ):
612887 @method (preserve = True )
613888 def eval (self ) -> object :
614- report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , self ), 0 )
889+ report = (EGraph .current or EGraph ())._run_extract (cast (" RuntimeExpr" , self ), 0 )
615890 assert isinstance (report , bindings .Best )
616891 expr = report .termdag .term_to_expr (report .term , bindings .PanicSpan ())
617892 return GLOBAL_PY_OBJECT_SORT .load (expr )
@@ -743,7 +1018,7 @@ def value_to_annotation(a: object) -> type | None:
7431018 # only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
7441019 if not isinstance (a , RuntimeExpr ):
7451020 return None
746- return cast (type , RuntimeClass (Thunk .value (a .__egg_decls__ ), a .__egg_typed_expr__ .tp .to_var ()))
1021+ return cast (" type" , RuntimeClass (Thunk .value (a .__egg_decls__ ), a .__egg_typed_expr__ .tp .to_var ()))
7471022
7481023
7491024converter (FunctionType , UnstableFn , _convert_function )
@@ -753,7 +1028,7 @@ def _extract_lit(e: BaseExpr) -> bindings._Literal:
7531028 """
7541029 Special case extracting literals to make this faster by using termdag directly.
7551030 """
756- report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , e ), 0 )
1031+ report = (EGraph .current or EGraph ())._run_extract (cast (" RuntimeExpr" , e ), 0 )
7571032 assert isinstance (report , bindings .Best )
7581033 term = report .term
7591034 assert isinstance (term , bindings .TermLit )
@@ -764,7 +1039,7 @@ def _extract_call(e: BaseExpr) -> CallDecl:
7641039 """
7651040 Extracts the call form of an expression
7661041 """
767- extracted = cast (RuntimeExpr , (EGraph .current or EGraph ()).extract (e ))
1042+ extracted = cast (" RuntimeExpr" , (EGraph .current or EGraph ()).extract (e ))
7681043 expr = extracted .__egg_typed_expr__ .expr
7691044 assert isinstance (expr , CallDecl )
7701045 return expr
0 commit comments