26
26
27
27
28
28
__all__ = [
29
+ "BigInt" ,
30
+ "BigIntLike" ,
31
+ "BigRat" ,
32
+ "BigRatLike" ,
29
33
"Bool" ,
30
34
"BoolLike" ,
31
35
"Map" ,
32
36
"MapLike" ,
37
+ "MultiSet" ,
33
38
"PyObject" ,
34
39
"Rational" ,
35
40
"Set" ,
@@ -329,14 +334,14 @@ class Map(BuiltinExpr, Generic[T, V]):
329
334
@method (preserve = True )
330
335
def eval (self ) -> dict [T , V ]:
331
336
call = _extract_call (self )
332
- expr = cast (RuntimeExpr , self )
337
+ expr = cast (" RuntimeExpr" , self )
333
338
d = {}
334
339
while call .callable != ClassMethodRef ("Map" , "empty" ):
335
340
assert call .callable == MethodRef ("Map" , "insert" )
336
341
call_typed , k_typed , v_typed = call .args
337
342
assert isinstance (call_typed .expr , CallDecl )
338
- k = cast (T , expr .__with_expr__ (k_typed ))
339
- v = cast (V , expr .__with_expr__ (v_typed ))
343
+ k = cast ("T" , expr .__with_expr__ (k_typed ))
344
+ v = cast ("V" , expr .__with_expr__ (v_typed ))
340
345
d [k ] = v
341
346
call = call_typed .expr
342
347
return d
@@ -397,7 +402,7 @@ class Set(BuiltinExpr, Generic[T]):
397
402
def eval (self ) -> set [T ]:
398
403
call = _extract_call (self )
399
404
assert call .callable == InitRef ("Set" )
400
- return {cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args }
405
+ return {cast ("T" , cast (" RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args }
401
406
402
407
@method (preserve = True )
403
408
def __iter__ (self ) -> Iterator [T ]:
@@ -454,6 +459,53 @@ def rebuild(self) -> Set[T]: ...
454
459
SetLike : TypeAlias = Set [T ] | set [TO ]
455
460
456
461
462
+ class MultiSet (BuiltinExpr , Generic [T ]):
463
+ @method (preserve = True )
464
+ def eval (self ) -> list [T ]:
465
+ call = _extract_call (self )
466
+ assert call .callable == InitRef ("MultiSet" )
467
+ return [cast ("T" , cast ("RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args ]
468
+
469
+ @method (preserve = True )
470
+ def __iter__ (self ) -> Iterator [T ]:
471
+ return iter (self .eval ())
472
+
473
+ @method (preserve = True )
474
+ def __len__ (self ) -> int :
475
+ return len (self .eval ())
476
+
477
+ @method (preserve = True )
478
+ def __contains__ (self , key : T ) -> bool :
479
+ return key in self .eval ()
480
+
481
+ @method (egg_fn = "multiset-of" )
482
+ def __init__ (self , * args : T ) -> None : ...
483
+
484
+ @method (egg_fn = "multiset-insert" )
485
+ def insert (self , value : T ) -> MultiSet [T ]: ...
486
+
487
+ @method (egg_fn = "multiset-not-contains" )
488
+ def not_contains (self , value : T ) -> Unit : ...
489
+
490
+ @method (egg_fn = "multiset-contains" )
491
+ def contains (self , value : T ) -> Unit : ...
492
+
493
+ @method (egg_fn = "multiset-remove" )
494
+ def remove (self , value : T ) -> MultiSet [T ]: ...
495
+
496
+ @method (egg_fn = "multiset-length" )
497
+ def length (self ) -> i64 : ...
498
+
499
+ @method (egg_fn = "multiset-pick" )
500
+ def pick (self ) -> T : ...
501
+
502
+ @method (egg_fn = "multiset-sum" )
503
+ def __add__ (self , other : MultiSet [T ]) -> MultiSet [T ]: ...
504
+
505
+ @method (egg_fn = "unstable-multiset-map" , reverse_args = True )
506
+ def map (self , f : Callable [[T ], T ]) -> MultiSet [T ]: ...
507
+
508
+
457
509
class Rational (BuiltinExpr ):
458
510
@method (preserve = True )
459
511
def eval (self ) -> Fraction :
@@ -537,14 +589,237 @@ def numer(self) -> i64: ...
537
589
def denom (self ) -> i64 : ...
538
590
539
591
592
+ class BigInt (BuiltinExpr ):
593
+ @method (preserve = True )
594
+ def eval (self ) -> int :
595
+ call = _extract_call (self )
596
+ assert call .callable == ClassMethodRef ("BigInt" , "from_string" )
597
+ (s ,) = call .args
598
+ assert isinstance (s .expr , LitDecl )
599
+ assert isinstance (s .expr .value , str )
600
+ return int (s .expr .value )
601
+
602
+ @method (preserve = True )
603
+ def __index__ (self ) -> int :
604
+ return self .eval ()
605
+
606
+ @method (preserve = True )
607
+ def __int__ (self ) -> int :
608
+ return self .eval ()
609
+
610
+ @method (egg_fn = "from-string" )
611
+ @classmethod
612
+ def from_string (cls , s : StringLike ) -> BigInt : ...
613
+
614
+ @method (egg_fn = "bigint" )
615
+ def __init__ (self , value : i64Like ) -> None : ...
616
+
617
+ @method (egg_fn = "+" )
618
+ def __add__ (self , other : BigIntLike ) -> BigInt : ...
619
+
620
+ @method (egg_fn = "-" )
621
+ def __sub__ (self , other : BigIntLike ) -> BigInt : ...
622
+
623
+ @method (egg_fn = "*" )
624
+ def __mul__ (self , other : BigIntLike ) -> BigInt : ...
625
+
626
+ @method (egg_fn = "/" )
627
+ def __truediv__ (self , other : BigIntLike ) -> BigInt : ...
628
+
629
+ @method (egg_fn = "%" )
630
+ def __mod__ (self , other : BigIntLike ) -> BigInt : ...
631
+
632
+ @method (egg_fn = "&" )
633
+ def __and__ (self , other : BigIntLike ) -> BigInt : ...
634
+
635
+ @method (egg_fn = "|" )
636
+ def __or__ (self , other : BigIntLike ) -> BigInt : ...
637
+
638
+ @method (egg_fn = "^" )
639
+ def __xor__ (self , other : BigIntLike ) -> BigInt : ...
640
+
641
+ @method (egg_fn = "<<" )
642
+ def __lshift__ (self , other : i64Like ) -> BigInt : ...
643
+
644
+ @method (egg_fn = ">>" )
645
+ def __rshift__ (self , other : i64Like ) -> BigInt : ...
646
+
647
+ def __radd__ (self , other : BigIntLike ) -> BigInt : ...
648
+
649
+ def __rsub__ (self , other : BigIntLike ) -> BigInt : ...
650
+
651
+ def __rmul__ (self , other : BigIntLike ) -> BigInt : ...
652
+
653
+ def __rtruediv__ (self , other : BigIntLike ) -> BigInt : ...
654
+
655
+ def __rmod__ (self , other : BigIntLike ) -> BigInt : ...
656
+
657
+ def __rand__ (self , other : BigIntLike ) -> BigInt : ...
658
+
659
+ def __ror__ (self , other : BigIntLike ) -> BigInt : ...
660
+
661
+ def __rxor__ (self , other : BigIntLike ) -> BigInt : ...
662
+
663
+ @method (egg_fn = "not-Z" )
664
+ def __invert__ (self ) -> BigInt : ...
665
+
666
+ @method (egg_fn = "bits" )
667
+ def bits (self ) -> BigInt : ...
668
+
669
+ @method (egg_fn = "<" )
670
+ def __lt__ (self , other : BigIntLike ) -> Unit : # type: ignore[empty-body,has-type]
671
+ ...
672
+
673
+ @method (egg_fn = ">" )
674
+ def __gt__ (self , other : BigIntLike ) -> Unit : ...
675
+
676
+ @method (egg_fn = "<=" )
677
+ def __le__ (self , other : BigIntLike ) -> Unit : # type: ignore[empty-body,has-type]
678
+ ...
679
+
680
+ @method (egg_fn = ">=" )
681
+ def __ge__ (self , other : BigIntLike ) -> Unit : ...
682
+
683
+ @method (egg_fn = "min" )
684
+ def min (self , other : BigIntLike ) -> BigInt : ...
685
+
686
+ @method (egg_fn = "max" )
687
+ def max (self , other : BigIntLike ) -> BigInt : ...
688
+
689
+ @method (egg_fn = "to-string" )
690
+ def to_string (self ) -> String : ...
691
+
692
+ @method (egg_fn = "bool-=" )
693
+ def bool_eq (self , other : BigIntLike ) -> Bool : ...
694
+
695
+ @method (egg_fn = "bool-<" )
696
+ def bool_lt (self , other : BigIntLike ) -> Bool : ...
697
+
698
+ @method (egg_fn = "bool->" )
699
+ def bool_gt (self , other : BigIntLike ) -> Bool : ...
700
+
701
+ @method (egg_fn = "bool-<=" )
702
+ def bool_le (self , other : BigIntLike ) -> Bool : ...
703
+
704
+ @method (egg_fn = "bool->=" )
705
+ def bool_ge (self , other : BigIntLike ) -> Bool : ...
706
+
707
+
708
+ converter (i64 , BigInt , lambda i : BigInt (i ))
709
+
710
+ BigIntLike : TypeAlias = BigInt | i64Like
711
+
712
+
713
+ class BigRat (BuiltinExpr ):
714
+ @method (preserve = True )
715
+ def eval (self ) -> Fraction :
716
+ call = _extract_call (self )
717
+ assert call .callable == InitRef ("BigRat" )
718
+
719
+ def _to_fraction (e : TypedExprDecl ) -> Fraction :
720
+ expr = e .expr
721
+ assert isinstance (expr , CallDecl )
722
+ assert expr .callable == ClassMethodRef ("BigInt" , "from_string" )
723
+ (s ,) = expr .args
724
+ assert isinstance (s .expr , LitDecl )
725
+ assert isinstance (s .expr .value , str )
726
+ return Fraction (s .expr .value )
727
+
728
+ num , den = call .args
729
+ return Fraction (_to_fraction (num ), _to_fraction (den ))
730
+
731
+ @method (preserve = True )
732
+ def __float__ (self ) -> float :
733
+ return float (self .eval ())
734
+
735
+ @method (preserve = True )
736
+ def __int__ (self ) -> int :
737
+ return int (self .eval ())
738
+
739
+ @method (egg_fn = "bigrat" )
740
+ def __init__ (self , num : BigIntLike , den : BigIntLike ) -> None : ...
741
+
742
+ @method (egg_fn = "to-f64" )
743
+ def to_f64 (self ) -> f64 : ...
744
+
745
+ @method (egg_fn = "+" )
746
+ def __add__ (self , other : BigRatLike ) -> BigRat : ...
747
+
748
+ @method (egg_fn = "-" )
749
+ def __sub__ (self , other : BigRatLike ) -> BigRat : ...
750
+
751
+ @method (egg_fn = "*" )
752
+ def __mul__ (self , other : BigRatLike ) -> BigRat : ...
753
+
754
+ @method (egg_fn = "/" )
755
+ def __truediv__ (self , other : BigRatLike ) -> BigRat : ...
756
+
757
+ @method (egg_fn = "min" )
758
+ def min (self , other : BigRatLike ) -> BigRat : ...
759
+
760
+ @method (egg_fn = "max" )
761
+ def max (self , other : BigRatLike ) -> BigRat : ...
762
+
763
+ @method (egg_fn = "neg" )
764
+ def __neg__ (self ) -> BigRat : ...
765
+
766
+ @method (egg_fn = "abs" )
767
+ def __abs__ (self ) -> BigRat : ...
768
+
769
+ @method (egg_fn = "floor" )
770
+ def floor (self ) -> BigRat : ...
771
+
772
+ @method (egg_fn = "ceil" )
773
+ def ceil (self ) -> BigRat : ...
774
+
775
+ @method (egg_fn = "round" )
776
+ def round (self ) -> BigRat : ...
777
+
778
+ @method (egg_fn = "pow" )
779
+ def __pow__ (self , other : BigRatLike ) -> BigRat : ...
780
+
781
+ @method (egg_fn = "log" )
782
+ def log (self ) -> BigRat : ...
783
+
784
+ @method (egg_fn = "sqrt" )
785
+ def sqrt (self ) -> BigRat : ...
786
+
787
+ @method (egg_fn = "cbrt" )
788
+ def cbrt (self ) -> BigRat : ...
789
+
790
+ @method (egg_fn = "numer" ) # type: ignore[misc]
791
+ @property
792
+ def numer (self ) -> BigInt : ...
793
+
794
+ @method (egg_fn = "denom" ) # type: ignore[misc]
795
+ @property
796
+ def denom (self ) -> BigInt : ...
797
+
798
+ @method (egg_fn = "<" )
799
+ def __lt__ (self , other : BigRatLike ) -> Unit : ... # type: ignore[has-type]
800
+
801
+ @method (egg_fn = ">" )
802
+ def __gt__ (self , other : BigRatLike ) -> Unit : ...
803
+
804
+ @method (egg_fn = ">=" )
805
+ def __ge__ (self , other : BigRatLike ) -> Unit : ... # type: ignore[has-type]
806
+
807
+ @method (egg_fn = "<=" )
808
+ def __le__ (self , other : BigRatLike ) -> Unit : ...
809
+
810
+
811
+ converter (Fraction , BigRat , lambda f : BigRat (f .numerator , f .denominator ))
812
+ BigRatLike : TypeAlias = BigRat | Fraction
813
+
814
+
540
815
class Vec (BuiltinExpr , Generic [T ]):
541
816
@method (preserve = True )
542
817
def eval (self ) -> tuple [T , ...]:
543
818
call = _extract_call (self )
544
819
if call .callable == ClassMethodRef ("Vec" , "empty" ):
545
820
return ()
546
821
assert call .callable == InitRef ("Vec" )
547
- return tuple (cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args )
822
+ return tuple (cast ("T" , cast (" RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args )
548
823
549
824
@method (preserve = True )
550
825
def __iter__ (self ) -> Iterator [T ]:
@@ -611,7 +886,7 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
611
886
class PyObject (BuiltinExpr ):
612
887
@method (preserve = True )
613
888
def eval (self ) -> object :
614
- report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , self ), 0 )
889
+ report = (EGraph .current or EGraph ())._run_extract (cast (" RuntimeExpr" , self ), 0 )
615
890
assert isinstance (report , bindings .Best )
616
891
expr = report .termdag .term_to_expr (report .term , bindings .PanicSpan ())
617
892
return GLOBAL_PY_OBJECT_SORT .load (expr )
@@ -743,7 +1018,7 @@ def value_to_annotation(a: object) -> type | None:
743
1018
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
744
1019
if not isinstance (a , RuntimeExpr ):
745
1020
return None
746
- return cast (type , RuntimeClass (Thunk .value (a .__egg_decls__ ), a .__egg_typed_expr__ .tp .to_var ()))
1021
+ return cast (" type" , RuntimeClass (Thunk .value (a .__egg_decls__ ), a .__egg_typed_expr__ .tp .to_var ()))
747
1022
748
1023
749
1024
converter (FunctionType , UnstableFn , _convert_function )
@@ -753,7 +1028,7 @@ def _extract_lit(e: BaseExpr) -> bindings._Literal:
753
1028
"""
754
1029
Special case extracting literals to make this faster by using termdag directly.
755
1030
"""
756
- report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , e ), 0 )
1031
+ report = (EGraph .current or EGraph ())._run_extract (cast (" RuntimeExpr" , e ), 0 )
757
1032
assert isinstance (report , bindings .Best )
758
1033
term = report .term
759
1034
assert isinstance (term , bindings .TermLit )
@@ -764,7 +1039,7 @@ def _extract_call(e: BaseExpr) -> CallDecl:
764
1039
"""
765
1040
Extracts the call form of an expression
766
1041
"""
767
- extracted = cast (RuntimeExpr , (EGraph .current or EGraph ()).extract (e ))
1042
+ extracted = cast (" RuntimeExpr" , (EGraph .current or EGraph ()).extract (e ))
768
1043
expr = extracted .__egg_typed_expr__ .expr
769
1044
assert isinstance (expr , CallDecl )
770
1045
return expr
0 commit comments