12
12
13
13
from typing_extensions import TypeVarTuple , Unpack
14
14
15
- from . import bindings
16
15
from .conversion import convert , converter , get_type_args
17
16
from .declarations import *
18
- from .egraph import BaseExpr , BuiltinExpr , EGraph , expr_fact , function , get_current_ruleset , method
19
- from .egraph_state import GLOBAL_PY_OBJECT_SORT
17
+ from .egraph import BaseExpr , BuiltinExpr , expr_fact , function , get_current_ruleset , method
20
18
from .functionalize import functionalize
21
19
from .runtime import RuntimeClass , RuntimeExpr , RuntimeFunction
22
20
from .thunk import Thunk
32
30
"BigRatLike" ,
33
31
"Bool" ,
34
32
"BoolLike" ,
33
+ "BuiltinEvalError" ,
35
34
"Map" ,
36
35
"MapLike" ,
37
36
"MultiSet" ,
56
55
]
57
56
58
57
58
+ class BuiltinEvalError (Exception ):
59
+ """
60
+ Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
61
+
62
+ Try extracting this expression first.
63
+ """
64
+
65
+ def __str__ (self ) -> str :
66
+ return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: { super ().__str__ ()} "
67
+
68
+
59
69
class Unit (BuiltinExpr , egg_sort = "Unit" ):
60
70
"""
61
71
The unit type. This is used to reprsent if a value exists in the e-graph or not.
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
72
82
@method (preserve = True )
73
83
def eval (self ) -> str :
74
84
value = _extract_lit (self )
75
- assert isinstance (value , bindings . String )
76
- return value . value
85
+ assert isinstance (value , str )
86
+ return value
77
87
78
88
def __init__ (self , value : str ) -> None : ...
79
89
@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
97
107
@method (preserve = True )
98
108
def eval (self ) -> bool :
99
109
value = _extract_lit (self )
100
- assert isinstance (value , bindings . Bool )
101
- return value . value
110
+ assert isinstance (value , bool )
111
+ return value
102
112
103
113
@method (preserve = True )
104
114
def __bool__ (self ) -> bool :
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
132
142
@method (preserve = True )
133
143
def eval (self ) -> int :
134
144
value = _extract_lit (self )
135
- assert isinstance (value , bindings . Int )
136
- return value . value
145
+ assert isinstance (value , int )
146
+ return value
137
147
138
148
@method (preserve = True )
139
149
def __index__ (self ) -> int :
@@ -251,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
251
261
@method (preserve = True )
252
262
def eval (self ) -> float :
253
263
value = _extract_lit (self )
254
- assert isinstance (value , bindings . Float )
255
- return value . value
264
+ assert isinstance (value , float )
265
+ return value
256
266
257
267
@method (preserve = True )
258
268
def __float__ (self ) -> float :
@@ -340,9 +350,12 @@ def eval(self) -> dict[T, V]:
340
350
expr = cast ("RuntimeExpr" , self )
341
351
d = {}
342
352
while call .callable != ClassMethodRef ("Map" , "empty" ):
343
- assert call .callable == MethodRef ("Map" , "insert" )
353
+ msg = "Map can only be evaluated if it is empty or a series of inserts."
354
+ if call .callable != MethodRef ("Map" , "insert" ):
355
+ raise BuiltinEvalError (msg )
344
356
call_typed , k_typed , v_typed = call .args
345
- assert isinstance (call_typed .expr , CallDecl )
357
+ if not isinstance (call_typed .expr , CallDecl ):
358
+ raise BuiltinEvalError (msg )
346
359
k = cast ("T" , expr .__with_expr__ (k_typed ))
347
360
v = cast ("V" , expr .__with_expr__ (v_typed ))
348
361
d [k ] = v
@@ -404,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
404
417
@method (preserve = True )
405
418
def eval (self ) -> set [T ]:
406
419
call = _extract_call (self )
407
- assert call .callable == InitRef ("Set" )
420
+ if call .callable != InitRef ("Set" ):
421
+ msg = "Set can only be initialized with the Set constructor."
422
+ raise BuiltinEvalError (msg )
408
423
return {cast ("T" , cast ("RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args }
409
424
410
425
@method (preserve = True )
@@ -466,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
466
481
@method (preserve = True )
467
482
def eval (self ) -> list [T ]:
468
483
call = _extract_call (self )
469
- assert call .callable == InitRef ("MultiSet" )
484
+ if call .callable != InitRef ("MultiSet" ):
485
+ msg = "MultiSet can only be initialized with the MultiSet constructor."
486
+ raise BuiltinEvalError (msg )
470
487
return [cast ("T" , cast ("RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args ]
471
488
472
489
@method (preserve = True )
@@ -513,11 +530,15 @@ class Rational(BuiltinExpr):
513
530
@method (preserve = True )
514
531
def eval (self ) -> Fraction :
515
532
call = _extract_call (self )
516
- assert call .callable == InitRef ("Rational" )
533
+ if call .callable != InitRef ("Rational" ):
534
+ msg = "Rational can only be initialized with the Rational constructor."
535
+ raise BuiltinEvalError (msg )
517
536
518
537
def _to_int (e : TypedExprDecl ) -> int :
519
538
expr = e .expr
520
- assert isinstance (expr , LitDecl )
539
+ if not isinstance (expr , LitDecl ):
540
+ msg = "Rational can only be initialized with literals"
541
+ raise BuiltinEvalError (msg )
521
542
assert isinstance (expr .value , int )
522
543
return expr .value
523
544
@@ -596,9 +617,13 @@ class BigInt(BuiltinExpr):
596
617
@method (preserve = True )
597
618
def eval (self ) -> int :
598
619
call = _extract_call (self )
599
- assert call .callable == ClassMethodRef ("BigInt" , "from_string" )
620
+ if call .callable != ClassMethodRef ("BigInt" , "from_string" ):
621
+ msg = "BigInt can only be initialized with the BigInt constructor."
622
+ raise BuiltinEvalError (msg )
600
623
(s ,) = call .args
601
- assert isinstance (s .expr , LitDecl )
624
+ if not isinstance (s .expr , LitDecl ):
625
+ msg = "BigInt can only be initialized with literals"
626
+ raise BuiltinEvalError (msg )
602
627
assert isinstance (s .expr .value , str )
603
628
return int (s .expr .value )
604
629
@@ -717,14 +742,19 @@ class BigRat(BuiltinExpr):
717
742
@method (preserve = True )
718
743
def eval (self ) -> Fraction :
719
744
call = _extract_call (self )
720
- assert call .callable == InitRef ("BigRat" )
745
+ if call .callable != InitRef ("BigRat" ):
746
+ msg = "BigRat can only be initialized with the BigRat constructor."
747
+ raise BuiltinEvalError (msg )
721
748
722
749
def _to_fraction (e : TypedExprDecl ) -> Fraction :
723
750
expr = e .expr
724
- assert isinstance (expr , CallDecl )
725
- assert expr .callable == ClassMethodRef ("BigInt" , "from_string" )
751
+ if not isinstance (expr , CallDecl ) or expr .callable != ClassMethodRef ("BigInt" , "from_string" ):
752
+ msg = "BigRat can only be initialized BigInt strings"
753
+ raise BuiltinEvalError (msg )
726
754
(s ,) = expr .args
727
- assert isinstance (s .expr , LitDecl )
755
+ if not isinstance (s .expr , LitDecl ):
756
+ msg = "BigInt can only be initialized with literals"
757
+ raise BuiltinEvalError (msg )
728
758
assert isinstance (s .expr .value , str )
729
759
return Fraction (s .expr .value )
730
760
@@ -821,7 +851,10 @@ def eval(self) -> tuple[T, ...]:
821
851
call = _extract_call (self )
822
852
if call .callable == ClassMethodRef ("Vec" , "empty" ):
823
853
return ()
824
- assert call .callable == InitRef ("Vec" )
854
+
855
+ if call .callable != InitRef ("Vec" ):
856
+ msg = "Vec can only be initialized with the Vec constructor."
857
+ raise BuiltinEvalError (msg )
825
858
return tuple (cast ("T" , cast ("RuntimeExpr" , self ).__with_expr__ (x )) for x in call .args )
826
859
827
860
@method (preserve = True )
@@ -889,10 +922,11 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
889
922
class PyObject (BuiltinExpr ):
890
923
@method (preserve = True )
891
924
def eval (self ) -> object :
892
- report = (EGraph .current or EGraph ())._run_extract (cast ("RuntimeExpr" , self ), 0 )
893
- assert isinstance (report , bindings .Best )
894
- expr = report .termdag .term_to_expr (report .term , bindings .PanicSpan ())
895
- return GLOBAL_PY_OBJECT_SORT .load (expr )
925
+ expr = cast ("RuntimeExpr" , self ).__egg_typed_expr__ .expr
926
+ if not isinstance (expr , PyObjectDecl ):
927
+ msg = "PyObject can only be evaluated if it is a PyObject literal"
928
+ raise BuiltinEvalError (msg )
929
+ return expr .value
896
930
897
931
def __init__ (self , value : object ) -> None : ...
898
932
@@ -1027,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
1027
1061
converter (FunctionType , UnstableFn , _convert_function )
1028
1062
1029
1063
1030
- def _extract_lit (e : BaseExpr ) -> bindings . _Literal :
1064
+ def _extract_lit (e : BaseExpr ) -> LitType :
1031
1065
"""
1032
1066
Special case extracting literals to make this faster by using termdag directly.
1033
1067
"""
1034
- report = ( EGraph . current or EGraph ()). _run_extract ( cast ("RuntimeExpr" , e ), 0 )
1035
- assert isinstance (report , bindings . Best )
1036
- term = report . term
1037
- assert isinstance ( term , bindings . TermLit )
1038
- return term .value
1068
+ expr = cast ("RuntimeExpr" , e ). __egg_typed_expr__ . expr
1069
+ if not isinstance (expr , LitDecl ):
1070
+ msg = "Expected a literal"
1071
+ raise BuiltinEvalError ( msg )
1072
+ return expr .value
1039
1073
1040
1074
1041
1075
def _extract_call (e : BaseExpr ) -> CallDecl :
1042
1076
"""
1043
1077
Extracts the call form of an expression
1044
1078
"""
1045
- extracted = cast ("RuntimeExpr" , (EGraph .current or EGraph ()).extract (e ))
1046
- expr = extracted .__egg_typed_expr__ .expr
1047
- assert isinstance (expr , CallDecl )
1079
+ expr = cast ("RuntimeExpr" , e ).__egg_typed_expr__ .expr
1080
+ if not isinstance (expr , CallDecl ):
1081
+ msg = "Expected a call expression"
1082
+ raise BuiltinEvalError (msg )
1048
1083
return expr
0 commit comments