Skip to content

Commit e8be112

Browse files
Merge pull request #271 from egraphs-good/builtins
Add multiset, bigint, bigrat builtins
2 parents d282494 + 4e4f2b1 commit e8be112

File tree

12 files changed

+524
-135
lines changed

12 files changed

+524
-135
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ _This project uses semantic versioning_
2727
- Removes `eval` method from `EGraph` and moves primitive evaluation to methods on each builtin and support `int(...)` type conversions on primitives. [#265](https://github.com/egraphs-good/egglog-python/pull/265)
2828
- Change how to set global EGraph context with `with egraph.set_current()` and `EGraph.current` and add support for setting global schedule as well with `with schedule.set_current()` and `Schedule.current`. [#265](https://github.com/egraphs-good/egglog-python/pull/265)
2929
- Adds support for using `==` and `!=` directly on values instead of `eq` and `ne` functions. [#265](https://github.com/egraphs-good/egglog-python/pull/265)
30+
- Add multiset, bigint, and bigrat builtins
3031

3132
## 8.0.1 (2024-10-24)
3233

python/egglog/builtins.py

Lines changed: 284 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@
2626

2727

2828
__all__ = [
29+
"BigInt",
30+
"BigIntLike",
31+
"BigRat",
32+
"BigRatLike",
2933
"Bool",
3034
"BoolLike",
3135
"Map",
3236
"MapLike",
37+
"MultiSet",
3338
"PyObject",
3439
"Rational",
3540
"Set",
@@ -329,14 +334,14 @@ class Map(BuiltinExpr, Generic[T, V]):
329334
@method(preserve=True)
330335
def eval(self) -> dict[T, V]:
331336
call = _extract_call(self)
332-
expr = cast(RuntimeExpr, self)
337+
expr = cast("RuntimeExpr", self)
333338
d = {}
334339
while call.callable != ClassMethodRef("Map", "empty"):
335340
assert call.callable == MethodRef("Map", "insert")
336341
call_typed, k_typed, v_typed = call.args
337342
assert isinstance(call_typed.expr, CallDecl)
338-
k = cast(T, expr.__with_expr__(k_typed))
339-
v = cast(V, expr.__with_expr__(v_typed))
343+
k = cast("T", expr.__with_expr__(k_typed))
344+
v = cast("V", expr.__with_expr__(v_typed))
340345
d[k] = v
341346
call = call_typed.expr
342347
return d
@@ -397,7 +402,7 @@ class Set(BuiltinExpr, Generic[T]):
397402
def eval(self) -> set[T]:
398403
call = _extract_call(self)
399404
assert call.callable == InitRef("Set")
400-
return {cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args}
405+
return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
401406

402407
@method(preserve=True)
403408
def __iter__(self) -> Iterator[T]:
@@ -454,6 +459,53 @@ def rebuild(self) -> Set[T]: ...
454459
SetLike: TypeAlias = Set[T] | set[TO]
455460

456461

462+
class MultiSet(BuiltinExpr, Generic[T]):
463+
@method(preserve=True)
464+
def eval(self) -> list[T]:
465+
call = _extract_call(self)
466+
assert call.callable == InitRef("MultiSet")
467+
return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
468+
469+
@method(preserve=True)
470+
def __iter__(self) -> Iterator[T]:
471+
return iter(self.eval())
472+
473+
@method(preserve=True)
474+
def __len__(self) -> int:
475+
return len(self.eval())
476+
477+
@method(preserve=True)
478+
def __contains__(self, key: T) -> bool:
479+
return key in self.eval()
480+
481+
@method(egg_fn="multiset-of")
482+
def __init__(self, *args: T) -> None: ...
483+
484+
@method(egg_fn="multiset-insert")
485+
def insert(self, value: T) -> MultiSet[T]: ...
486+
487+
@method(egg_fn="multiset-not-contains")
488+
def not_contains(self, value: T) -> Unit: ...
489+
490+
@method(egg_fn="multiset-contains")
491+
def contains(self, value: T) -> Unit: ...
492+
493+
@method(egg_fn="multiset-remove")
494+
def remove(self, value: T) -> MultiSet[T]: ...
495+
496+
@method(egg_fn="multiset-length")
497+
def length(self) -> i64: ...
498+
499+
@method(egg_fn="multiset-pick")
500+
def pick(self) -> T: ...
501+
502+
@method(egg_fn="multiset-sum")
503+
def __add__(self, other: MultiSet[T]) -> MultiSet[T]: ...
504+
505+
@method(egg_fn="unstable-multiset-map", reverse_args=True)
506+
def map(self, f: Callable[[T], T]) -> MultiSet[T]: ...
507+
508+
457509
class Rational(BuiltinExpr):
458510
@method(preserve=True)
459511
def eval(self) -> Fraction:
@@ -537,14 +589,237 @@ def numer(self) -> i64: ...
537589
def denom(self) -> i64: ...
538590

539591

592+
class BigInt(BuiltinExpr):
593+
@method(preserve=True)
594+
def eval(self) -> int:
595+
call = _extract_call(self)
596+
assert call.callable == ClassMethodRef("BigInt", "from_string")
597+
(s,) = call.args
598+
assert isinstance(s.expr, LitDecl)
599+
assert isinstance(s.expr.value, str)
600+
return int(s.expr.value)
601+
602+
@method(preserve=True)
603+
def __index__(self) -> int:
604+
return self.eval()
605+
606+
@method(preserve=True)
607+
def __int__(self) -> int:
608+
return self.eval()
609+
610+
@method(egg_fn="from-string")
611+
@classmethod
612+
def from_string(cls, s: StringLike) -> BigInt: ...
613+
614+
@method(egg_fn="bigint")
615+
def __init__(self, value: i64Like) -> None: ...
616+
617+
@method(egg_fn="+")
618+
def __add__(self, other: BigIntLike) -> BigInt: ...
619+
620+
@method(egg_fn="-")
621+
def __sub__(self, other: BigIntLike) -> BigInt: ...
622+
623+
@method(egg_fn="*")
624+
def __mul__(self, other: BigIntLike) -> BigInt: ...
625+
626+
@method(egg_fn="/")
627+
def __truediv__(self, other: BigIntLike) -> BigInt: ...
628+
629+
@method(egg_fn="%")
630+
def __mod__(self, other: BigIntLike) -> BigInt: ...
631+
632+
@method(egg_fn="&")
633+
def __and__(self, other: BigIntLike) -> BigInt: ...
634+
635+
@method(egg_fn="|")
636+
def __or__(self, other: BigIntLike) -> BigInt: ...
637+
638+
@method(egg_fn="^")
639+
def __xor__(self, other: BigIntLike) -> BigInt: ...
640+
641+
@method(egg_fn="<<")
642+
def __lshift__(self, other: i64Like) -> BigInt: ...
643+
644+
@method(egg_fn=">>")
645+
def __rshift__(self, other: i64Like) -> BigInt: ...
646+
647+
def __radd__(self, other: BigIntLike) -> BigInt: ...
648+
649+
def __rsub__(self, other: BigIntLike) -> BigInt: ...
650+
651+
def __rmul__(self, other: BigIntLike) -> BigInt: ...
652+
653+
def __rtruediv__(self, other: BigIntLike) -> BigInt: ...
654+
655+
def __rmod__(self, other: BigIntLike) -> BigInt: ...
656+
657+
def __rand__(self, other: BigIntLike) -> BigInt: ...
658+
659+
def __ror__(self, other: BigIntLike) -> BigInt: ...
660+
661+
def __rxor__(self, other: BigIntLike) -> BigInt: ...
662+
663+
@method(egg_fn="not-Z")
664+
def __invert__(self) -> BigInt: ...
665+
666+
@method(egg_fn="bits")
667+
def bits(self) -> BigInt: ...
668+
669+
@method(egg_fn="<")
670+
def __lt__(self, other: BigIntLike) -> Unit: # type: ignore[empty-body,has-type]
671+
...
672+
673+
@method(egg_fn=">")
674+
def __gt__(self, other: BigIntLike) -> Unit: ...
675+
676+
@method(egg_fn="<=")
677+
def __le__(self, other: BigIntLike) -> Unit: # type: ignore[empty-body,has-type]
678+
...
679+
680+
@method(egg_fn=">=")
681+
def __ge__(self, other: BigIntLike) -> Unit: ...
682+
683+
@method(egg_fn="min")
684+
def min(self, other: BigIntLike) -> BigInt: ...
685+
686+
@method(egg_fn="max")
687+
def max(self, other: BigIntLike) -> BigInt: ...
688+
689+
@method(egg_fn="to-string")
690+
def to_string(self) -> String: ...
691+
692+
@method(egg_fn="bool-=")
693+
def bool_eq(self, other: BigIntLike) -> Bool: ...
694+
695+
@method(egg_fn="bool-<")
696+
def bool_lt(self, other: BigIntLike) -> Bool: ...
697+
698+
@method(egg_fn="bool->")
699+
def bool_gt(self, other: BigIntLike) -> Bool: ...
700+
701+
@method(egg_fn="bool-<=")
702+
def bool_le(self, other: BigIntLike) -> Bool: ...
703+
704+
@method(egg_fn="bool->=")
705+
def bool_ge(self, other: BigIntLike) -> Bool: ...
706+
707+
708+
converter(i64, BigInt, lambda i: BigInt(i))
709+
710+
BigIntLike: TypeAlias = BigInt | i64Like
711+
712+
713+
class BigRat(BuiltinExpr):
714+
@method(preserve=True)
715+
def eval(self) -> Fraction:
716+
call = _extract_call(self)
717+
assert call.callable == InitRef("BigRat")
718+
719+
def _to_fraction(e: TypedExprDecl) -> Fraction:
720+
expr = e.expr
721+
assert isinstance(expr, CallDecl)
722+
assert expr.callable == ClassMethodRef("BigInt", "from_string")
723+
(s,) = expr.args
724+
assert isinstance(s.expr, LitDecl)
725+
assert isinstance(s.expr.value, str)
726+
return Fraction(s.expr.value)
727+
728+
num, den = call.args
729+
return Fraction(_to_fraction(num), _to_fraction(den))
730+
731+
@method(preserve=True)
732+
def __float__(self) -> float:
733+
return float(self.eval())
734+
735+
@method(preserve=True)
736+
def __int__(self) -> int:
737+
return int(self.eval())
738+
739+
@method(egg_fn="bigrat")
740+
def __init__(self, num: BigIntLike, den: BigIntLike) -> None: ...
741+
742+
@method(egg_fn="to-f64")
743+
def to_f64(self) -> f64: ...
744+
745+
@method(egg_fn="+")
746+
def __add__(self, other: BigRatLike) -> BigRat: ...
747+
748+
@method(egg_fn="-")
749+
def __sub__(self, other: BigRatLike) -> BigRat: ...
750+
751+
@method(egg_fn="*")
752+
def __mul__(self, other: BigRatLike) -> BigRat: ...
753+
754+
@method(egg_fn="/")
755+
def __truediv__(self, other: BigRatLike) -> BigRat: ...
756+
757+
@method(egg_fn="min")
758+
def min(self, other: BigRatLike) -> BigRat: ...
759+
760+
@method(egg_fn="max")
761+
def max(self, other: BigRatLike) -> BigRat: ...
762+
763+
@method(egg_fn="neg")
764+
def __neg__(self) -> BigRat: ...
765+
766+
@method(egg_fn="abs")
767+
def __abs__(self) -> BigRat: ...
768+
769+
@method(egg_fn="floor")
770+
def floor(self) -> BigRat: ...
771+
772+
@method(egg_fn="ceil")
773+
def ceil(self) -> BigRat: ...
774+
775+
@method(egg_fn="round")
776+
def round(self) -> BigRat: ...
777+
778+
@method(egg_fn="pow")
779+
def __pow__(self, other: BigRatLike) -> BigRat: ...
780+
781+
@method(egg_fn="log")
782+
def log(self) -> BigRat: ...
783+
784+
@method(egg_fn="sqrt")
785+
def sqrt(self) -> BigRat: ...
786+
787+
@method(egg_fn="cbrt")
788+
def cbrt(self) -> BigRat: ...
789+
790+
@method(egg_fn="numer") # type: ignore[misc]
791+
@property
792+
def numer(self) -> BigInt: ...
793+
794+
@method(egg_fn="denom") # type: ignore[misc]
795+
@property
796+
def denom(self) -> BigInt: ...
797+
798+
@method(egg_fn="<")
799+
def __lt__(self, other: BigRatLike) -> Unit: ... # type: ignore[has-type]
800+
801+
@method(egg_fn=">")
802+
def __gt__(self, other: BigRatLike) -> Unit: ...
803+
804+
@method(egg_fn=">=")
805+
def __ge__(self, other: BigRatLike) -> Unit: ... # type: ignore[has-type]
806+
807+
@method(egg_fn="<=")
808+
def __le__(self, other: BigRatLike) -> Unit: ...
809+
810+
811+
converter(Fraction, BigRat, lambda f: BigRat(f.numerator, f.denominator))
812+
BigRatLike: TypeAlias = BigRat | Fraction
813+
814+
540815
class Vec(BuiltinExpr, Generic[T]):
541816
@method(preserve=True)
542817
def eval(self) -> tuple[T, ...]:
543818
call = _extract_call(self)
544819
if call.callable == ClassMethodRef("Vec", "empty"):
545820
return ()
546821
assert call.callable == InitRef("Vec")
547-
return tuple(cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args)
822+
return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
548823

549824
@method(preserve=True)
550825
def __iter__(self) -> Iterator[T]:
@@ -611,7 +886,7 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
611886
class PyObject(BuiltinExpr):
612887
@method(preserve=True)
613888
def eval(self) -> object:
614-
report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, self), 0)
889+
report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", self), 0)
615890
assert isinstance(report, bindings.Best)
616891
expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
617892
return GLOBAL_PY_OBJECT_SORT.load(expr)
@@ -743,7 +1018,7 @@ def value_to_annotation(a: object) -> type | None:
7431018
# only lift runtime expressions (which could contain vars) not any other nonlocals/globals we use in the function
7441019
if not isinstance(a, RuntimeExpr):
7451020
return None
746-
return cast(type, RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
1021+
return cast("type", RuntimeClass(Thunk.value(a.__egg_decls__), a.__egg_typed_expr__.tp.to_var()))
7471022

7481023

7491024
converter(FunctionType, UnstableFn, _convert_function)
@@ -753,7 +1028,7 @@ def _extract_lit(e: BaseExpr) -> bindings._Literal:
7531028
"""
7541029
Special case extracting literals to make this faster by using termdag directly.
7551030
"""
756-
report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, e), 0)
1031+
report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", e), 0)
7571032
assert isinstance(report, bindings.Best)
7581033
term = report.term
7591034
assert isinstance(term, bindings.TermLit)
@@ -764,7 +1039,7 @@ def _extract_call(e: BaseExpr) -> CallDecl:
7641039
"""
7651040
Extracts the call form of an expression
7661041
"""
767-
extracted = cast(RuntimeExpr, (EGraph.current or EGraph()).extract(e))
1042+
extracted = cast("RuntimeExpr", (EGraph.current or EGraph()).extract(e))
7681043
expr = extracted.__egg_typed_expr__.expr
7691044
assert isinstance(expr, CallDecl)
7701045
return expr

0 commit comments

Comments
 (0)