diff --git a/docs/changelog.md b/docs/changelog.md index 7df8c05d..3f87f9d6 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,7 +4,9 @@ _This project uses semantic versioning_ ## UNRELEASED -- Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315) +- Add ability to parse egglog expressions into Python values [#319](https://github.com/egraphs-good/egglog-python/pull/319) + - Deprecates `.eval()` method on primitives in favor of `.value` which can be used with pattern matching. +- Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315) - Automatically Create Changelog Entry for PRs [#313](https://github.com/egraphs-good/egglog-python/pull/313) - Upgrade egglog which includes new backend. - Fixes implementation of the Python Object sort to work with objects with dupliating hashes but the same value. diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index d3a3d844..5976ad66 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -340,7 +340,7 @@ and also will make sure the variables won't be used outside of the scope of the # egg: (rewrite (Mul a b) (Mul b a)) # egg: (rewrite (Add a b) (Add b a)) -@egraph.register +@EGraph().register def _math(a: Math, b: Math): yield rewrite(a * b).to(b * a) yield rewrite(a + b).to(b + a) diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index cd82d0a0..c2bf4d98 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -6,17 +6,88 @@ file_format: mystnb Alongside [the support for builtin `egglog` functionality](./egglog-translation.md), `egglog` also provides functionality to more easily integrate with the Python ecosystem. -## Retrieving Primitive Values +## Retrieving Values -If you have a egglog primitive, you can turn it into a Python object by using `egraph.eval(...)` method: +If you have an egglog value, you might want to convert it from an expression to a native Python object. This is done through a number of helper functions: + +For a primitive value (like `i64`, `f64`, `Bool`, `String`, or `PyObject`), use `get_literal_value(expr)` or the `.value` property: ```{code-cell} python from __future__ import annotations from egglog import * -egraph = EGraph() -assert egraph.eval(i64(1) + 20) == 21 +assert get_literal_value(i64(42)) == 42 +assert get_literal_value(i64(42) + i64(1)) == None # This is because i64(42) + i64(1) is a call expression, not a literal +assert i64(42).value == 42 +assert get_literal_value(f64(3.14)) == 3.14 +assert Bool(True).value is True +assert String("hello").value == "hello" +assert PyObject([1,2,3]).value == [1,2,3] +``` + +To check if an expression is a let value and get its name, use `get_let_name(expr)`: + +```{code-cell} python +x = EGraph().let("my_var", i64(1)) +assert get_let_name(x) == "my_var" +``` + +To check if an expression is a variable and get its name, use `get_var_name(expr)`: + +```{code-cell} python +from egglog import var, get_var_name +v = var("x", i64) +assert get_var_name(v) == "x" +``` + +For a callable (method, function, classmethod, or constructor), use `get_callable_fn(expr)` to get the underlying Python function: + +```{code-cell} python +expr = i64(1) + i64(2) +fn = get_callable_fn(expr) +assert fn == i64.__add__ +``` + +To get the arguments to a callable, use `get_callable_args(expr)`. If you want to match against a specific callable, use `get_callable_args(expr, fn)`, where `fn` is the Python function you want to match against. This will return `None` if the callable does not match the function, and if it does match, the args will be properly typed: + +```{code-cell} python +assert get_callable_args(expr) == (i64(1), i64(2)) + +assert get_callable_args(expr, i64.__add__) == (i64(1), i64(2)) +assert get_callable_args(expr, i64.__sub__) == None +``` + +### Pattern Matching + +You can use Python's structural pattern matching (`match`/`case`) to destructure egglog primitives: + +```{code-cell} python +x = i64(5) +match i64(5): + case i64(i): + print(f"Integer literal: {i}") +``` + +You can add custom support for pattern matching against your classes by adding `__match_args__` to your class: + +```python +class MyExpr(Expr): + def __init__(self, value: StringLike): ... + + __match_args__ = ("value",) + + @method(preserve=True) + @property + def value(self) -> str: + match get_callable_args(self, MyExpr): + case (String(value),): + return value + raise ExprValueError(self, "MyExpr") + +match MyExpr("hello"): + case MyExpr(value): + print(f"Matched MyExpr with value: {value}") ``` ## Python Object Sort @@ -53,10 +124,10 @@ Creating hashable objects is safer, since while the rule might create new Python ### Retrieving Python Objects -Like other primitives, we can retrieve the Python object from the e-graph by using the `egraph.eval(...)` method: +Like other primitives, we can retrieve the Python object from the e-graph by using the `.value` property: ```{code-cell} python -assert egraph.eval(lst) == [1, 2, 3] +assert lst.value == [1, 2, 3] ``` ### Builtin methods @@ -66,29 +137,29 @@ Currently, we only support a few methods on `PyObject`s, but we plan to add more Conversion to/from a string: ```{code-cell} python -egraph.eval(PyObject('hi').to_string()) +EGraph().extract(PyObject('hi').to_string()) ``` ```{code-cell} python -egraph.eval(PyObject.from_string("1")) +EGraph().extract(PyObject.from_string("1")) ``` Conversion from an int: ```{code-cell} python -egraph.eval(PyObject.from_int(1)) +EGraph().extract(PyObject.from_int(1)) ``` We also support evaluating arbitrary Python code, given some locals and globals. This technically allows us to implement any Python method: ```{code-cell} python -egraph.eval(py_eval("1 + 2")) +EGraph().extract(py_eval("1 + 2")) ``` Executing Python code is also supported. In this case, the return value will be the updated globals dict, which will be copied first before using. ```{code-cell} python -egraph.eval(py_exec("x = 1 + 2")) +EGraph().extract(py_exec("x = 1 + 2")) ``` Alongside this, we support a function `dict_update` method, which can allow you to combine some local egglog expressions alongside, say, the locals and globals of the Python code you are evaluating. @@ -100,7 +171,7 @@ def my_add(a, b): amended_globals = PyObject(globals()).dict_update("one", 1) evalled = py_eval("my_add(one, 2)", locals(), amended_globals) -assert egraph.eval(evalled) == 3 +assert EGraph().extract(evalled).value == 3 ``` ### Simpler Eval @@ -116,7 +187,7 @@ def my_add(a, b): return a + b evalled = py_eval_fn(lambda a: my_add(a, 2))(1) -assert egraph.eval(evalled) == 3 +assert EGraph().extract(evalled).value == 3 ``` ## Functions @@ -263,10 +334,12 @@ class Boolean(Expr): # Run until the e-graph saturates egraph.run(10) # Extract the Python object from the e-graph - return egraph.eval(self.to_py()) - - def to_py(self) -> PyObject: - ... + value = EGraph().extract(self) + if value == TRUE: + return True + elif value == FALSE: + return False + raise ExprValueError(self, "Boolean expression must be TRUE or FALSE") def __or__(self, other: Boolean) -> Boolean: ... @@ -278,8 +351,6 @@ FALSE = egraph.constant("FALSE", Boolean) @egraph.register def _bool(x: Boolean): return [ - set_(TRUE.to_py()).to(PyObject(True)), - set_(FALSE.to_py()).to(PyObject(False)), rewrite(TRUE | x).to(TRUE), rewrite(FALSE | x).to(x), ] diff --git a/python/egglog/__init__.py b/python/egglog/__init__.py index 7e5344d3..af64575d 100644 --- a/python/egglog/__init__.py +++ b/python/egglog/__init__.py @@ -6,6 +6,7 @@ from .bindings import EggSmolError # noqa: F401 from .builtins import * # noqa: UP029 from .conversion import * +from .deconstruct import * from .egraph import * from .runtime import define_expr_method as define_expr_method # noqa: PLC0414 diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index d361d25b..9d8b9cd4 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -6,20 +6,20 @@ from __future__ import annotations from collections.abc import Callable +from dataclasses import dataclass from fractions import Fraction from functools import partial, reduce from inspect import signature from types import FunctionType, MethodType from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload -from typing_extensions import TypeVarTuple, Unpack - -from egglog.declarations import TypedExprDecl +from typing_extensions import TypeVarTuple, Unpack, deprecated from .conversion import convert, converter, get_type_args, resolve_literal from .declarations import * +from .deconstruct import get_callable_args, get_literal_value from .egraph import BaseExpr, BuiltinExpr, _add_default_rewrite_inner, expr_fact, function, get_current_ruleset, method -from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate +from .runtime import RuntimeExpr, RuntimeFunction, resolve_type_annotation_mutate from .thunk import Thunk if TYPE_CHECKING: @@ -33,7 +33,7 @@ "BigRatLike", "Bool", "BoolLike", - "BuiltinEvalError", + "ExprValueError", "Map", "MapLike", "MultiSet", @@ -58,15 +58,17 @@ ] -class BuiltinEvalError(Exception): +@dataclass +class ExprValueError(AttributeError): """ - Raised when an builtin cannot be evaluated into a Python primitive because it is complex. - - Try extracting this expression first. + Raised when an expression cannot be converted to a Python value because the value is not a constructor. """ + expr: BaseExpr + allowed: str + def __str__(self) -> str: - return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}" + return f"Cannot get Python value of {self.expr}, must be of form {self.allowed}. Try calling `extract` on it to get the underlying value." class Unit(BuiltinExpr, egg_sort="Unit"): @@ -82,13 +84,21 @@ def __bool__(self) -> bool: class String(BuiltinExpr): + def __init__(self, value: str) -> None: ... + @method(preserve=True) + @deprecated("use .value") def eval(self) -> str: - value = _extract_lit(self) - assert isinstance(value, str) - return value + return self.value - def __init__(self, value: str) -> None: ... + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> str: + if (value := get_literal_value(self)) is not None: + return value + raise ExprValueError(self, "String") + + __match_args__ = ("value",) @method(egg_fn="replace") def replace(self, old: StringLike, new: StringLike) -> String: ... @@ -105,17 +115,25 @@ def join(*strings: StringLike) -> String: ... class Bool(BuiltinExpr, egg_sort="bool"): + def __init__(self, value: bool) -> None: ... + @method(preserve=True) + @deprecated("use .value") def eval(self) -> bool: - value = _extract_lit(self) - assert isinstance(value, bool) - return value + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> bool: + if (value := get_literal_value(self)) is not None: + return value + raise ExprValueError(self, "Bool") + + __match_args__ = ("value",) @method(preserve=True) def __bool__(self) -> bool: - return self.eval() - - def __init__(self, value: bool) -> None: ... + return self.value @method(egg_fn="not") def __invert__(self) -> Bool: ... @@ -140,21 +158,29 @@ def implies(self, other: BoolLike) -> Bool: ... class i64(BuiltinExpr): # noqa: N801 + def __init__(self, value: int) -> None: ... + @method(preserve=True) + @deprecated("use .value") def eval(self) -> int: - value = _extract_lit(self) - assert isinstance(value, int) - return value + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> int: + if (value := get_literal_value(self)) is not None: + return value + raise ExprValueError(self, "i64") + + __match_args__ = ("value",) @method(preserve=True) def __index__(self) -> int: - return self.eval() + return self.value @method(preserve=True) def __int__(self) -> int: - return self.eval() - - def __init__(self, value: int) -> None: ... + return self.value @method(egg_fn="+") def __add__(self, other: i64Like) -> i64: ... @@ -259,21 +285,29 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ... class f64(BuiltinExpr): # noqa: N801 + def __init__(self, value: float) -> None: ... + @method(preserve=True) + @deprecated("use .value") def eval(self) -> float: - value = _extract_lit(self) - assert isinstance(value, float) - return value + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> float: + if (value := get_literal_value(self)) is not None: + return value + raise ExprValueError(self, "f64") + + __match_args__ = ("value",) @method(preserve=True) def __float__(self) -> float: - return self.eval() + return self.value @method(preserve=True) def __int__(self) -> int: - return int(self.eval()) - - def __init__(self, value: float) -> None: ... + return int(self.value) @method(egg_fn="neg") def __neg__(self) -> f64: ... @@ -349,34 +383,34 @@ def to_string(self) -> String: ... class Map(BuiltinExpr, Generic[T, V]): @method(preserve=True) + @deprecated("use .value") def eval(self) -> dict[T, V]: - call = _extract_call(self) - expr = cast("RuntimeExpr", self) + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> dict[T, V]: d = {} - while call.callable != ClassMethodRef("Map", "empty"): - msg = "Map can only be evaluated if it is empty or a series of inserts." - if call.callable != MethodRef("Map", "insert"): - raise BuiltinEvalError(msg) - call_typed, k_typed, v_typed = call.args - if not isinstance(call_typed.expr, CallDecl): - raise BuiltinEvalError(msg) - k = cast("T", expr.__with_expr__(k_typed)) - v = cast("V", expr.__with_expr__(v_typed)) + while args := get_callable_args(self, Map[T, V].insert): + self, k, v = args # noqa: PLW0642 d[k] = v - call = call_typed.expr + if get_callable_args(self, Map.empty) is None: + raise ExprValueError(self, "Map.empty or Map.insert") return d + __match_args__ = ("value",) + @method(preserve=True) def __iter__(self) -> Iterator[T]: - return iter(self.eval()) + return iter(self.value) @method(preserve=True) def __len__(self) -> int: - return len(self.eval()) + return len(self.value) @method(preserve=True) def __contains__(self, key: T) -> bool: - return key in self.eval() + return key in self.value @method(egg_fn="map-empty") @classmethod @@ -419,24 +453,30 @@ def rebuild(self) -> Map[T, V]: ... class Set(BuiltinExpr, Generic[T]): @method(preserve=True) + @deprecated("use .value") def eval(self) -> set[T]: - call = _extract_call(self) - if call.callable != InitRef("Set"): - msg = "Set can only be initialized with the Set constructor." - raise BuiltinEvalError(msg) - return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args} + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> set[T]: + if (args := get_callable_args(self, Set[T])) is not None: + return set(args) + raise ExprValueError(self, "Set(*xs)") + + __match_args__ = ("value",) @method(preserve=True) def __iter__(self) -> Iterator[T]: - return iter(self.eval()) + return iter(self.value) @method(preserve=True) def __len__(self) -> int: - return len(self.eval()) + return len(self.value) @method(preserve=True) def __contains__(self, key: T) -> bool: - return key in self.eval() + return key in self.value @method(egg_fn="set-of") def __init__(self, *args: T) -> None: ... @@ -483,24 +523,30 @@ def rebuild(self) -> Set[T]: ... class MultiSet(BuiltinExpr, Generic[T]): @method(preserve=True) + @deprecated("use .value") def eval(self) -> list[T]: - call = _extract_call(self) - if call.callable != InitRef("MultiSet"): - msg = "MultiSet can only be initialized with the MultiSet constructor." - raise BuiltinEvalError(msg) - return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args] + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> list[T]: + if (args := get_callable_args(self, MultiSet[T])) is not None: + return list(args) + raise ExprValueError(self, "MultiSet") + + __match_args__ = ("value",) @method(preserve=True) def __iter__(self) -> Iterator[T]: - return iter(self.eval()) + return iter(self.value) @method(preserve=True) def __len__(self) -> int: - return len(self.eval()) + return len(self.value) @method(preserve=True) def __contains__(self, key: T) -> bool: - return key in self.eval() + return key in self.value @method(egg_fn="multiset-of") def __init__(self, *args: T) -> None: ... @@ -532,30 +578,27 @@ def map(self, f: Callable[[T], T]) -> MultiSet[T]: ... class Rational(BuiltinExpr): @method(preserve=True) + @deprecated("use .value") def eval(self) -> Fraction: - call = _extract_call(self) - if call.callable != InitRef("Rational"): - msg = "Rational can only be initialized with the Rational constructor." - raise BuiltinEvalError(msg) - - def _to_int(e: TypedExprDecl) -> int: - expr = e.expr - if not isinstance(expr, LitDecl): - msg = "Rational can only be initialized with literals" - raise BuiltinEvalError(msg) - assert isinstance(expr.value, int) - return expr.value - - num, den = call.args - return Fraction(_to_int(num), _to_int(den)) + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> Fraction: + match get_callable_args(self, Rational): + case (i64(num), i64(den)): + return Fraction(num, den) + raise ExprValueError(self, "Rational(i64(num), i64(den))") + + __match_args__ = ("value",) @method(preserve=True) def __float__(self) -> float: - return float(self.eval()) + return float(self.value) @method(preserve=True) def __int__(self) -> int: - return int(self.eval()) + return int(self.value) @method(egg_fn="rational") def __init__(self, num: i64Like, den: i64Like) -> None: ... @@ -619,25 +662,27 @@ def denom(self) -> i64: ... class BigInt(BuiltinExpr): @method(preserve=True) + @deprecated("use .value") def eval(self) -> int: - call = _extract_call(self) - if call.callable != ClassMethodRef("BigInt", "from_string"): - msg = "BigInt can only be initialized with the BigInt constructor." - raise BuiltinEvalError(msg) - (s,) = call.args - if not isinstance(s.expr, LitDecl): - msg = "BigInt can only be initialized with literals" - raise BuiltinEvalError(msg) - assert isinstance(s.expr.value, str) - return int(s.expr.value) + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> int: + match get_callable_args(self, BigInt.from_string): + case (String(s),): + return int(s) + raise ExprValueError(self, "BigInt.from_string(String(s))") + + __match_args__ = ("value",) @method(preserve=True) def __index__(self) -> int: - return self.eval() + return self.value @method(preserve=True) def __int__(self) -> int: - return self.eval() + return self.value @method(egg_fn="from-string") @classmethod @@ -744,34 +789,27 @@ def bool_ge(self, other: BigIntLike) -> Bool: ... class BigRat(BuiltinExpr): @method(preserve=True) + @deprecated("use .value") def eval(self) -> Fraction: - call = _extract_call(self) - if call.callable != InitRef("BigRat"): - msg = "BigRat can only be initialized with the BigRat constructor." - raise BuiltinEvalError(msg) - - def _to_fraction(e: TypedExprDecl) -> Fraction: - expr = e.expr - if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"): - msg = "BigRat can only be initialized BigInt strings" - raise BuiltinEvalError(msg) - (s,) = expr.args - if not isinstance(s.expr, LitDecl): - msg = "BigInt can only be initialized with literals" - raise BuiltinEvalError(msg) - assert isinstance(s.expr.value, str) - return Fraction(s.expr.value) - - num, den = call.args - return Fraction(_to_fraction(num), _to_fraction(den)) + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> Fraction: + match get_callable_args(self, BigRat): + case (BigInt(num), BigInt(den)): + return Fraction(num, den) + raise ExprValueError(self, "BigRat(BigInt(num), BigInt(den))") + + __match_args__ = ("value",) @method(preserve=True) def __float__(self) -> float: - return float(self.eval()) + return float(self.value) @method(preserve=True) def __int__(self) -> int: - return int(self.eval()) + return int(self.value) @method(egg_fn="bigrat") def __init__(self, num: BigIntLike, den: BigIntLike) -> None: ... @@ -851,27 +889,32 @@ def __le__(self, other: BigRatLike) -> Unit: ... class Vec(BuiltinExpr, Generic[T]): @method(preserve=True) + @deprecated("use .value") def eval(self) -> tuple[T, ...]: - call = _extract_call(self) - if call.callable == ClassMethodRef("Vec", "empty"): + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> tuple[T, ...]: + if get_callable_args(self, Vec.empty) is not None: return () + if (args := get_callable_args(self, Vec[T])) is not None: + return args + raise ExprValueError(self, "Vec(*xs) or Vec.empty()") - if call.callable != InitRef("Vec"): - msg = "Vec can only be initialized with the Vec constructor." - raise BuiltinEvalError(msg) - return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args) + __match_args__ = ("value",) @method(preserve=True) def __iter__(self) -> Iterator[T]: - return iter(self.eval()) + return iter(self.value) @method(preserve=True) def __len__(self) -> int: - return len(self.eval()) + return len(self.value) @method(preserve=True) def __contains__(self, key: T) -> bool: - return key in self.eval() + return key in self.value @method(egg_fn="vec-of") def __init__(self, *args: T) -> None: ... @@ -925,13 +968,20 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ... class PyObject(BuiltinExpr): @method(preserve=True) + @deprecated("use .value") def eval(self) -> object: + return self.value + + @method(preserve=True) # type: ignore[misc] + @property + def value(self) -> object: expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr if not isinstance(expr, PyObjectDecl): - msg = "PyObject can only be evaluated if it is a PyObject literal" - raise BuiltinEvalError(msg) + raise ExprValueError(self, "PyObject(x)") return expr.value + __match_args__ = ("value",) + def __init__(self, value: object) -> None: ... @method(egg_fn="py-from-string") @@ -1021,46 +1071,25 @@ def __init__(self, f: Callable[[T1, T2, Unpack[TS]], T], _a: T1, _b: T2, /) -> N @method(egg_fn="unstable-fn") def __init__(self, f, *partial) -> None: ... - @method(egg_fn="unstable-app") - def __call__(self, *args: Unpack[TS]) -> T: ... - @method(preserve=True) + @deprecated("use .value") def eval(self) -> Callable[[Unpack[TS]], T]: + return self.value + + @method(preserve=True) # type: ignore[prop-decorator] + @property + def value(self) -> Callable[[Unpack[TS]], T]: """ If this is a constructor, returns either the callable directly or a `functools.partial` function if args are provided. """ - assert isinstance(self, RuntimeExpr) - match self.__egg_typed_expr__.expr: - case PartialCallDecl(CallDecl() as call): - fn, args = _deconstruct_call_decl(self.__egg_decls_thunk__, call) - if not args: - return fn - return partial(fn, *args) - msg = "UnstableFn can only be evaluated if it is a function or a partial application of a function." - raise BuiltinEvalError(msg) - - -def _deconstruct_call_decl( - decls_thunk: Callable[[], Declarations], call: CallDecl -) -> tuple[Callable, tuple[object, ...]]: - """ - Deconstructs a CallDecl into a runtime callable and its arguments. - """ - args = call.args - arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args) - egg_bound = ( - JustTypeRef(call.callable.class_name, call.bound_tp_params or ()) - if isinstance(call.callable, (ClassMethodRef, InitRef, ClassVariableRef)) - else None - ) - if isinstance(call.callable, InitRef): - return RuntimeClass( - decls_thunk, - TypeRefWithVars( - call.callable.class_name, - ), - ), arg_exprs - return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs + if (fn := get_literal_value(self)) is not None: + return fn + raise ExprValueError(self, "UnstableFn(f, *args)") + + __match_args__ = ("value",) + + @method(egg_fn="unstable-app") + def __call__(self, *args: Unpack[TS]) -> T: ... # Method Type is for builtins like __getitem__ @@ -1102,29 +1131,3 @@ def _convert_function(fn: FunctionType) -> UnstableFn: converter(FunctionType, UnstableFn, _convert_function) - -## -# Utility Functions -## - - -def _extract_lit(e: BaseExpr) -> LitType: - """ - Special case extracting literals to make this faster by using termdag directly. - """ - expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr - if not isinstance(expr, LitDecl): - msg = "Expected a literal" - raise BuiltinEvalError(msg) - return expr.value - - -def _extract_call(e: BaseExpr) -> CallDecl: - """ - Extracts the call form of an expression - """ - expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr - if not isinstance(expr, CallDecl): - msg = "Expected a call expression" - raise BuiltinEvalError(msg) - return expr diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 24512704..b1c1368f 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -84,7 +84,7 @@ ] -@dataclass +@dataclass(match_args=False) class DelayedDeclerations: __egg_decls_thunk__: Callable[[], Declarations] = field(repr=False) @@ -286,6 +286,7 @@ class ClassDecl: methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict) preserved_methods: dict[str, Callable] = field(default_factory=dict) + match_args: tuple[str, ...] = field(default=()) @dataclass(frozen=True) diff --git a/python/egglog/deconstruct.py b/python/egglog/deconstruct.py new file mode 100644 index 00000000..aab80705 --- /dev/null +++ b/python/egglog/deconstruct.py @@ -0,0 +1,173 @@ +""" +Utility functions to deconstruct expressions in Python. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import partial +from typing import TYPE_CHECKING, TypeVar, overload + +from typing_extensions import TypeVarTuple, Unpack + +from .declarations import * +from .egraph import BaseExpr +from .runtime import * +from .thunk import * + +if TYPE_CHECKING: + from .builtins import Bool, PyObject, String, UnstableFn, f64, i64 + + +T = TypeVar("T", bound=BaseExpr) +TS = TypeVarTuple("TS", default=Unpack[tuple[BaseExpr, ...]]) + +__all__ = ["get_callable_args", "get_callable_fn", "get_let_name", "get_literal_value", "get_var_name"] + + +@overload +def get_literal_value(x: String) -> str | None: ... + + +@overload +def get_literal_value(x: Bool) -> bool | None: ... + + +@overload +def get_literal_value(x: i64) -> int | None: ... + + +@overload +def get_literal_value(x: f64) -> float | None: ... + + +@overload +def get_literal_value(x: PyObject) -> object: ... + + +@overload +def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ... + + +def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object: + """ + Returns the literal value of an expression if it is a literal. + If it is not a literal, returns None. + """ + if not isinstance(x, RuntimeExpr): + raise TypeError(f"Expected Expression, got {type(x).__name__}") + match x.__egg_typed_expr__.expr: + case LitDecl(v): + return v + case PyObjectDecl(obj): + return obj + case PartialCallDecl(call): + fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call) + if not args: + return fn + return partial(fn, *args) + return None + + +def get_let_name(x: BaseExpr) -> str | None: + """ + Check if the expression is a `let` expression and return the name of the variable. + If it is not a `let` expression, return None. + """ + if not isinstance(x, RuntimeExpr): + raise TypeError(f"Expected Expression, got {type(x).__name__}") + match x.__egg_typed_expr__.expr: + case LetRefDecl(name): + return name + return None + + +def get_var_name(x: BaseExpr) -> str | None: + """ + Check if the expression is a variable and return its name. + If it is not a variable, return None. + """ + if not isinstance(x, RuntimeExpr): + raise TypeError(f"Expected Expression, got {type(x).__name__}") + match x.__egg_typed_expr__.expr: + case UnboundVarDecl(name, _egg_name): + return name + return None + + +def get_callable_fn(x: T) -> Callable[..., T] | None: + """ + Gets the function of an expression if it is a call expression. + If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None. + For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()` + to return the Python value. + """ + if not isinstance(x, RuntimeExpr): + raise TypeError(f"Expected Expression, got {type(x).__name__}") + match x.__egg_typed_expr__.expr: + case CallDecl() as call: + fn, _ = _deconstruct_call_decl(x.__egg_decls_thunk__, call) + return fn + return None + + +@overload +def get_callable_args(x: T, fn: None = ...) -> tuple[BaseExpr, ...]: ... + + +@overload +def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T]) -> tuple[Unpack[TS]] | None: ... + + +def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T] | None = None) -> tuple[Unpack[TS]] | None: + """ + Gets all the arguments of an expression. + If a function is provided, it will only return the arguments if the expression is a call + to that function. + + Note that recursively calling the arguments is the safe way to walk the expression tree. + """ + if not isinstance(x, RuntimeExpr): + raise TypeError(f"Expected Expression, got {type(x).__name__}") + match x.__egg_typed_expr__.expr: + case CallDecl() as call: + actual_fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call) + if fn is None: + return args + # Compare functions and classes without considering bound type parameters, so that you can pass + # in a binding like Vec[i64] and match Vec[i64](...) or Vec(...) calls. + if ( + isinstance(actual_fn, RuntimeFunction) + and isinstance(fn, RuntimeFunction) + and actual_fn.__egg_ref__ == fn.__egg_ref__ + ): + return args + if ( + isinstance(actual_fn, RuntimeClass) + and isinstance(fn, RuntimeClass) + and actual_fn.__egg_tp__.name == fn.__egg_tp__.name + ): + return args + return None + + +def _deconstruct_call_decl( + decls_thunk: Callable[[], Declarations], call: CallDecl +) -> tuple[Callable, tuple[object, ...]]: + """ + Deconstructs a CallDecl into a runtime callable and its arguments. + """ + args = call.args + arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args) + if isinstance(call.callable, InitRef): + return RuntimeClass( + decls_thunk, + TypeRefWithVars(call.callable.class_name, tuple(tp.to_var() for tp in (call.bound_tp_params or []))), + ), arg_exprs + egg_bound = ( + JustTypeRef(call.callable.class_name, call.bound_tp_params or ()) + if isinstance(call.callable, ClassMethodRef) + else None + ) + + return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index f973ac31..3869bc31 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -5,7 +5,7 @@ import pathlib import tempfile from collections.abc import Callable, Generator, Iterable -from contextvars import ContextVar +from contextvars import ContextVar, Token from dataclasses import InitVar, dataclass, field from functools import partial from inspect import Parameter, currentframe, signature @@ -114,6 +114,7 @@ "__qualname__", "__firstlineno__", "__static_attributes__", + "__match_args__", # Ignore all reflected binary method *(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS), } @@ -361,6 +362,7 @@ class BaseExpr(metaclass=_ExprMetaclass): def __ne__(self, other: Self) -> Unit: ... # type: ignore[override, empty-body] + # not currently dissalowing other types of equality https://github.com/python/typeshed/issues/8217#issuecomment-3140873292 def __eq__(self, other: Self) -> Fact: ... # type: ignore[override, empty-body] @@ -394,7 +396,7 @@ def _generate_class_decls( # noqa: C901,PLR0912 ) type_vars = tuple(ClassTypeVarRef.from_type_var(p) for p in parameters) del parameters - cls_decl = ClassDecl(egg_sort, type_vars, builtin) + cls_decl = ClassDecl(egg_sort, type_vars, builtin, match_args=namespace.pop("__match_args__", ())) decls = Declarations(_classes={cls_name: cls_decl}) # Update class think eagerly when resolving so that lookups work in methods runtime_cls.__egg_decls_thunk__ = Thunk.value(decls) @@ -446,6 +448,9 @@ def _generate_class_decls( # noqa: C901,PLR0912 continue locals = frame.f_locals ref: ClassMethodRef | MethodRef | PropertyRef | InitRef + # TODO: Store deprecated message so we can print at runtime + if (getattr(fn, "__deprecated__", None)) is not None: + fn = fn.__wrapped__ # type: ignore[attr-defined] match fn: case classmethod(): ref = ClassMethodRef(cls_name, method_name) @@ -854,6 +859,7 @@ def _callable_to_egg(self, fn: object) -> str: self._add_decls(decls) return self._state.callable_ref_to_egg(ref)[0] + # TODO: Change let to be action... def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR: """ Define a new expression in the egraph and return a reference to it. @@ -1501,16 +1507,17 @@ def rule(*facts: FactLike, ruleset: None = None, name: str | None = None) -> _Ru return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset) -def var(name: str, bound: type[T]) -> T: +def var(name: str, bound: type[T], egg_name: str | None = None) -> T: """Create a new variable with the given name and type.""" - return cast("T", _var(name, bound)) + return cast("T", _var(name, bound, egg_name=egg_name)) -def _var(name: str, bound: object) -> RuntimeExpr: +def _var(name: str, bound: object, egg_name: str | None) -> RuntimeExpr: """Create a new variable with the given name and type.""" decls_like, type_ref = resolve_type_annotation(bound) return RuntimeExpr( - Thunk.fn(Declarations.create, decls_like), Thunk.value(TypedExprDecl(type_ref.to_just(), UnboundVarDecl(name))) + Thunk.fn(Declarations.create, decls_like), + Thunk.value(TypedExprDecl(type_ref.to_just(), UnboundVarDecl(name, egg_name))), ) @@ -1760,7 +1767,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> # python/tests/test_no_import_star.py::test_no_import_star_rulesset combined = {**gen.__globals__, **frame.f_locals} hints = get_type_hints(gen, combined, combined) - args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()] + args = [_var(p.name, hints[p.name], egg_name=None) for p in signature(gen).parameters.values()] return list(gen(*args)) # type: ignore[misc] @@ -1786,7 +1793,7 @@ def get_current_ruleset() -> Ruleset | None: @contextlib.contextmanager def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]: - token = _CURRENT_RULESET.set(r) + token: Token[Ruleset | None] = _CURRENT_RULESET.set(r) try: yield finally: diff --git a/python/egglog/examples/bignum.py b/python/egglog/examples/bignum.py index d20e104f..a17d997d 100644 --- a/python/egglog/examples/bignum.py +++ b/python/egglog/examples/bignum.py @@ -14,7 +14,7 @@ egraph = EGraph() -assert egraph.extract(z.numer.to_string()).eval() == "-617" +assert egraph.extract(z.numer.to_string()).value == "-617" @function diff --git a/python/egglog/examples/multiset.py b/python/egglog/examples/multiset.py index c0e4b2b3..8306dadc 100644 --- a/python/egglog/examples/multiset.py +++ b/python/egglog/examples/multiset.py @@ -32,7 +32,7 @@ def math_ruleset(i: i64): egraph.check(xs == MultiSet(Math(1), Math(3), Math(2))) egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3))) -assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1}) +assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1}) inserted = MultiSet(Math(1), Math(2), Math(3), Math(4)) @@ -45,7 +45,7 @@ def math_ruleset(i: i64): egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3))) -assert egraph.extract(xs.length()).eval() == 3 +assert egraph.extract(xs.length()).value == 3 assert len(xs) == 3 egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2)) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 1a75724f..3e791835 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1982,4 +1982,4 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built except BaseException as e: # egraph.display(n_inline_leaves=1, split_primitive_outputs=True) raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904 - return extracted.eval() # type: ignore[attr-defined] + return extracted.value # type: ignore[attr-defined] diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index f43d51f1..02c7c2dd 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -13,8 +13,9 @@ import itertools import operator +import types from collections.abc import Callable -from dataclasses import dataclass, replace +from dataclasses import InitVar, dataclass, replace from inspect import Parameter, Signature from itertools import zip_longest from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin @@ -150,11 +151,48 @@ def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp: ## -@dataclass -class RuntimeClass(DelayedDeclerations): +class BaseClassFactoryMeta(type): + """ + Base metaclass for all runtime classes created by ClassFactory + """ + + def __instancecheck__(cls, instance: object) -> bool: + assert isinstance(cls, RuntimeClass) + return isinstance(instance, RuntimeExpr) and cls.__egg_tp__.name == instance.__egg_typed_expr__.tp.name + + +class ClassFactory(type): + """ + A metaclass for types which should create `type` objects when instantiated. + + That's so that they work with `isinstance` and can be placed in `match ClassName()`. + """ + + def __call__(cls, *args, **kwargs) -> type: + # If we have params, don't inherit from `type` because we don't need to match against this and also + # this won't work with `Union[X]` because it won't look at `__parameters__` for instances of `type`. + if kwargs.pop("_egg_has_params", False): + return super().__call__(*args, **kwargs) + namespace: dict[str, Any] = {} + for m in reversed(cls.__mro__): + namespace.update(m.__dict__) + init = namespace.pop("__init__") + meta = types.new_class("type(RuntimeClass)", (BaseClassFactoryMeta,), {}, lambda ns: ns.update(**namespace)) + tp = types.new_class("RuntimeClass", (), {"metaclass": meta}) + init(tp, *args, **kwargs) + return tp + + def __instancecheck__(cls, instance: object) -> bool: + return isinstance(instance, BaseClassFactoryMeta) + + +@dataclass(match_args=False) +class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory): __egg_tp__: TypeRefWithVars + # True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly. + _egg_has_params: InitVar[bool] = False - def __post_init__(self) -> None: + def __post_init__(self, _egg_has_params: bool) -> None: global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS if (name := self.__egg_tp__.name) == "PyObject": _PY_OBJECT_CLASS = self @@ -244,7 +282,7 @@ def __getitem__(self, args: object) -> RuntimeClass: else: final_args = new_args tp = TypeRefWithVars(self.__egg_tp__.name, final_args) - return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp) + return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True) def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: if name == "__origin__" and self.__egg_tp__.args: @@ -322,6 +360,10 @@ def __parameters__(self) -> tuple[object, ...]: """ return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args) + @property + def __match_args__(self) -> tuple[str, ...]: + return self.__egg_decls__._classes[self.__egg_tp__.name].match_args + @dataclass class RuntimeFunction(DelayedDeclerations): @@ -583,17 +625,17 @@ def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs): define_expr_method(name) -for name, reversed in itertools.product(NUMERIC_BINARY_METHODS, (False, True)): +for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)): - def _numeric_binary_method(self: object, other: object, name: str = name, reversed: bool = reversed) -> object: + def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object: """ Implements numeric binary operations. Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either the LHS or the RHS as exactly the right type and then upcasting the other to that type. """ - # 1. switch if reversed - if reversed: + # 1. switch if reversed method + if r_method: self, other = other, self # If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion if not ( @@ -646,7 +688,7 @@ def _numeric_binary_method(self: object, other: object, name: str = name, revers fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self) return fn(other) - method_name = f"__r{name[2:]}" if reversed else name + method_name = f"__r{name[2:]}" if r_method else name setattr(RuntimeExpr, method_name, _numeric_binary_method) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index ce9285a2..2020a6e5 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -323,8 +323,8 @@ def test_program_compile(program: Program, snapshot_py): egraph = EGraph() egraph.register(simplified_program.compile()) egraph.run(array_api_program_gen_schedule) - statements = egraph.extract(simplified_program.statements).eval() - expr = egraph.extract(simplified_program.expr).eval() + statements = egraph.extract(simplified_program.statements).value + expr = egraph.extract(simplified_program.expr).value assert "\n".join([*statements.split("\n"), expr]) == snapshot_py(name="code") diff --git a/python/tests/test_deconstruct.py b/python/tests/test_deconstruct.py new file mode 100644 index 00000000..f253142a --- /dev/null +++ b/python/tests/test_deconstruct.py @@ -0,0 +1,96 @@ +# mypy: disable-error-code="empty-body" +from __future__ import annotations + +from functools import partial +from typing import ClassVar + +import pytest + +from egglog import * + + +class X(Expr): + v: ClassVar[X] + + def __init__(self) -> None: ... + + @property + def p(self) -> X: ... + + @classmethod + def c(cls) -> X: ... + + def m(self, a: X) -> X: ... + + +@function +def f(x: X) -> X: ... + + +@function +def y(x: X, i: i64) -> X: ... + + +c = constant("c", X) + +v = var("v", X) +l = EGraph().let("l", X()) + + +@pytest.mark.parametrize( + ("expr", "value"), + [ + (i64(42), 42), + (i64(42) + i64(1), None), + (f64(3.14), 3.14), + (Bool(True), True), + (PyObject("test"), "test"), + (UnstableFn(f), f), + (UnstableFn(f, X()), partial(f, X())), + ], +) +def test_get_literal_value(expr, value): + res = get_literal_value(expr) + if isinstance(res, partial) and isinstance(value, partial): + assert res.func == value.func + assert res.args == value.args + assert res.keywords == value.keywords + else: + assert res == value + + +def test_get_let_name(): + assert get_let_name(l) == "l" + assert get_let_name(X()) is None + + +def test_get_var_name(): + assert get_var_name(v) == "v" + assert get_var_name(X()) is None + + +@pytest.mark.parametrize( + ("expr", "fn", "args"), + [ + pytest.param(f(X()), f, (X(),), id="function call"), + pytest.param(X().p, X.p, (X(),), id="property"), + pytest.param(X.c(), X.c, (), id="classmethod"), + pytest.param(X(), X, (), id="init"), + pytest.param(X().m(X()), X.m, (X(), X()), id="method call"), + pytest.param(Vec(i64(1)), Vec, (i64(1),), id="generic class"), + pytest.param(Vec[i64](), Vec[i64], (), id="generic parameter init"), + pytest.param(Vec[i64].empty(), Vec[i64].empty, (), id="generic parameter classmethod"), + ], +) +def test_callable(expr, fn, args): + assert get_callable_fn(expr) == fn + assert get_callable_args(expr) == args + assert get_callable_args(expr, fn) == args + + +def test_callable_generic_applied(): + assert get_callable_args(Vec(i64(1)), Vec[i64]) == (i64(1),) + + +def test_callable_generic_applied_method(): + assert get_callable_args(Vec[i64].empty(), Vec[i64].empty) == () diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 2165f832..7618e3f0 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -228,30 +228,30 @@ def foo(x: i64Like, y: i64Like = i64(1)) -> i64: ... class TestPyObject: def test_from_string(self): - assert EGraph().extract(PyObject.from_string("foo")).eval() == "foo" + assert EGraph().extract(PyObject.from_string("foo")).value == "foo" def test_to_string(self): EGraph().check(PyObject("foo").to_string() == String("foo")) def test_dict_update(self): original_d = {"foo": "bar"} - res = EGraph().extract(PyObject(original_d).dict_update("foo", "baz")).eval() + res = EGraph().extract(PyObject(original_d).dict_update("foo", "baz")).value assert res == {"foo": "baz"} assert original_d == {"foo": "bar"} def test_eval(self): - assert EGraph().extract(py_eval("x + y", {"x": 10, "y": 20}, {})).eval() == 30 + assert EGraph().extract(py_eval("x + y", {"x": 10, "y": 20}, {})).value == 30 def test_eval_local(self): x = "hi" res = py_eval("my_add(x, y)", PyObject(locals()).dict_update("y", "there"), globals()) - assert EGraph().extract(res).eval() == "hithere" + assert EGraph().extract(res).value == "hithere" def test_exec(self): - assert EGraph().extract(py_exec("x = 10")).eval() == {"x": 10} + assert EGraph().extract(py_exec("x = 10")).value == {"x": 10} def test_exec_globals(self): - assert EGraph().extract(py_exec("x = y + 1", {"y": 10})).eval() == {"x": 11} + assert EGraph().extract(py_exec("x = y + 1", {"y": 10})).value == {"x": 11} def my_add(a, b): @@ -431,27 +431,26 @@ def from_int(cls, other: Int) -> NDArray: ... class TestEval: def test_string(self): - assert String("hi").eval() == "hi" + assert String("hi").value == "hi" def test_bool(self): - assert Bool(True).eval() is True + assert Bool(True).value is True assert bool(Bool(True)) is True def test_i64(self): - assert i64(10).eval() == 10 + assert i64(10).value == 10 assert int(i64(10)) == 10 assert [10][i64(0)] == 10 def test_f64(self): - assert f64(10.0).eval() == 10.0 + assert f64(10.0).value == 10.0 assert int(f64(10.0)) == 10 assert float(f64(10.0)) == 10.0 def test_map(self): - assert Map[String, i64].empty().eval() == {} + assert Map[String, i64].empty().value == {} m = Map[String, i64].empty().insert(String("a"), i64(1)).insert(String("b"), i64(2)) - # TODO: Add __eq__ with eq() that evals to True on boolean comparison? And same with ne? - assert m.eval() == {String("a"): i64(1), String("b"): i64(2)} + assert m.value == {String("a"): i64(1), String("b"): i64(2)} assert set(m) == {String("a"), String("b")} assert len(m) == 2 @@ -459,9 +458,9 @@ def test_map(self): assert String("c") not in m def test_set(self): - assert EGraph().extract(Set[i64].empty()).eval() == set() + assert EGraph().extract(Set[i64].empty()).value == set() s = Set(i64(1), i64(2)) - assert s.eval() == {i64(1), i64(2)} + assert s.value == {i64(1), i64(2)} assert set(s) == {i64(1), i64(2)} assert len(s) == 2 @@ -469,14 +468,14 @@ def test_set(self): assert i64(3) not in s def test_rational(self): - assert Rational(1, 2).eval() == Fraction(1, 2) + assert Rational(1, 2).value == Fraction(1, 2) assert float(Rational(1, 2)) == 0.5 assert int(Rational(1, 1)) == 1 def test_vec(self): - assert Vec[i64].empty().eval() == () + assert Vec[i64].empty().value == () s = Vec(i64(1), i64(2)) - assert s.eval() == (i64(1), i64(2)) + assert s.value == (i64(1), i64(2)) assert list(s) == [i64(1), i64(2)] assert len(s) == 2 @@ -484,9 +483,9 @@ def test_vec(self): assert i64(3) not in s def test_py_object(self): - assert PyObject(10).eval() == 10 + assert PyObject(10).value == 10 o = object() - assert PyObject(o).eval() is o + assert PyObject(o).value is o def test_big_int(self): assert int(EGraph().extract(BigInt(10))) == 10 @@ -494,7 +493,7 @@ def test_big_int(self): def test_big_rat(self): br = EGraph().extract(BigRat(1, 2)) assert float(br) == 1 / 2 - assert br.eval() == Fraction(1, 2) + assert br.value == Fraction(1, 2) def test_multiset(self): assert list(MultiSet(i64(1), i64(1))) == [i64(1), i64(1)] @@ -507,9 +506,9 @@ def __init__(self) -> None: ... def f(x: Math) -> Math: ... u_f = UnstableFn(f) - assert u_f.eval() == f + assert u_f.value == f p_u_f = UnstableFn(f, Math()) - value = p_u_f.eval() + value = p_u_f.value assert isinstance(value, partial) assert value.func == f assert value.args == (Math(),) @@ -528,7 +527,7 @@ def f(x: Math) -> Math: ... def test_eval_fn(): - assert EGraph().extract(py_eval_fn(lambda x: (x,))(PyObject.from_int(1))).eval() == (1,) + assert EGraph().extract(py_eval_fn(lambda x: (x,))(PyObject.from_int(1))).value == (1,) def _global_make_tuple(x): @@ -536,14 +535,14 @@ def _global_make_tuple(x): def test_eval_fn_globals(): - assert EGraph().extract(py_eval_fn(lambda x: _global_make_tuple(x))(PyObject.from_int(1))).eval() == (1,) + assert EGraph().extract(py_eval_fn(lambda x: _global_make_tuple(x))(PyObject.from_int(1))).value == (1,) def test_eval_fn_locals(): def _locals_make_tuple(x): return (x,) - assert EGraph().extract(py_eval_fn(lambda x: _locals_make_tuple(x))(PyObject.from_int(1))).eval() == (1,) + assert EGraph().extract(py_eval_fn(lambda x: _locals_make_tuple(x))(PyObject.from_int(1))).value == (1,) def test_lazy_types(): @@ -835,6 +834,106 @@ def __eq__(self, other: B) -> B: ... # type: ignore[override] assert not isinstance(B() == B(), Fact) +def test_isinstance_expr(): + """ + Verifies that isinstance() works on Exprs, and returns a Fact + """ + + class A(Expr): + def __init__(self) -> None: ... + + class B(Expr): + def __init__(self) -> None: ... + + assert isinstance(A(), A) + assert not isinstance(A(), B) + + +class TestMatch: + def test_class(self): + """ + Verify that we can pattern match on expressions + """ + + class A(Expr): + def __init__(self) -> None: ... + + class B(Expr): + def __init__(self) -> None: ... + + a = A() + match a: + case B(): + msg = "Should not have matched B" + raise ValueError(msg) + case A(): + pass + case _: + msg = "Should have matched A" + raise ValueError(msg) + + def test_literal(self): + match i64(10): + case i64(i): + assert i == 10 + case _: + msg = "Should have matched i64(10)" + raise ValueError(msg) + + def test_literal_fail(self): + """ + Verify that matching on a literal that does not match raises an error + """ + match i64(10) + i64(10): + case i64(_i): + msg = "Should not have matched i64(20)" + raise ValueError(msg) + + def test_custom_args(self): + class A(Expr): + def __init__(self) -> None: ... + + __match_args__ = ("a", "b") + + @method(preserve=True) # type: ignore[misc] + @property + def a(self) -> int: + return 1 + + @method(preserve=True) # type: ignore[misc] + @property + def b(self) -> str: + return "hi" + + match A(): + case A(a, b): + assert a == 1 + assert b == "hi" + case _: + msg = "Should have matched A" + raise ValueError(msg) + + def test_custom_args_fail(self): + """ + Verify that matching on a custom match that does not match raises an error + """ + + class A(Expr): + def __init__(self) -> None: ... + + __match_args__ = ("a",) + + @method(preserve=True) # type: ignore[misc] + @property + def a(self) -> int: + raise AttributeError + + match A(): + case A(_a): + msg = "Should not have matched A" + raise ValueError(msg) + + T = TypeVar("T") diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 0ddde294..c6113128 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -60,7 +60,7 @@ def test_to_string(snapshot_py) -> None: egraph.register(fn.compile()) egraph.run((to_program_ruleset | program_gen_ruleset).saturate()) egraph.check(fn.expr == String("my_fn")) - assert egraph.extract(fn.statements).eval() == snapshot_py + assert egraph.extract(fn.statements).value == snapshot_py def test_to_string_function_three(snapshot_py) -> None: @@ -71,8 +71,8 @@ def test_to_string_function_three(snapshot_py) -> None: egraph = EGraph() egraph.register(fn.compile()) egraph.run((to_program_ruleset | program_gen_ruleset).saturate()) - assert egraph.extract(fn.expr).eval() == "my_fn" - assert egraph.extract(fn.statements).eval() == snapshot_py + assert egraph.extract(fn.expr).value == "my_fn" + assert egraph.extract(fn.statements).value == snapshot_py def test_py_object(): @@ -84,6 +84,6 @@ def test_py_object(): egraph = EGraph() egraph.register(evalled) egraph.run((to_program_ruleset | eval_program_rulseset | program_gen_ruleset).saturate()) - res = cast("FunctionType", egraph.extract(evalled.as_py_object).eval()) + res = cast("FunctionType", egraph.extract(evalled.as_py_object).value) assert res(1, 2) == 13 assert inspect.getsource(res)