diff --git a/.flake8 b/.flake8 index 8de4fc94..6addb2e2 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ ; https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8 [flake8] max-line-length = 88 -extend-ignore = E203,E501,F405,F403,E302 \ No newline at end of file +extend-ignore = E203,E501,F405,F403,E302,E305,F821 \ No newline at end of file diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 63c81f5e..b77c918e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -38,7 +38,7 @@ jobs: python-version: "3.10" - uses: actions/checkout@v2 - run: pip install -e . mypy - - run: python -m mypy.stubtest egg_smol.bindings + - run: python -m mypy.stubtest egg_smol.bindings --allowlist stubtest_allow docs: runs-on: ubuntu-latest steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b7c2b0af..97e17199 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,3 @@ repos: rev: 5.0.4 hooks: - id: flake8 - - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v0.982" - hooks: - - id: mypy diff --git a/Cargo.lock b/Cargo.lock index 87f1096b..ed5f0a0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,7 +148,7 @@ dependencies = [ [[package]] name = "egg-smol" version = "0.1.0" -source = "git+https://github.com/saulshanabrook/egg-smol?branch=public-api#cefd8fdcc3b521e9f95bea16bdcfc57c089a16ea" +source = "git+https://github.com/mwillsey/egg-smol#9a45bdee9b5395ab821318a033c6ac7fa91f91b9" dependencies = [ "clap", "env_logger", diff --git a/Cargo.toml b/Cargo.toml index dc03b6ac..7840bcd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.17.1", features = ["extension-module"] } -egg-smol = { git = "https://github.com/saulshanabrook/egg-smol", branch = "public-api" } +egg-smol = { git = "https://github.com/mwillsey/egg-smol", ref = "9a45bdee9b5395ab821318a033c6ac7fa91f91b9" } [package.metadata.maturin] name = "egg_smol.bindings" diff --git a/docs/explanation/compared_to_rust.md b/docs/explanation/compared_to_rust.md index d7f2b54f..4e372ec8 100644 --- a/docs/explanation/compared_to_rust.md +++ b/docs/explanation/compared_to_rust.md @@ -52,6 +52,8 @@ One way to run this in Python is to parse the text and run it similar to how the egg CLI works: ```{code-cell} python +from egg_smol.bindings import * + eqsat_basic = """(datatype Math (Num i64) (Var String) @@ -78,8 +80,6 @@ eqsat_basic = """(datatype Math (run 10) (check (= expr1 expr2))""" -from egg_smol.bindings import EGraph - egraph = EGraph() egraph.parse_and_run_program(eqsat_basic) ``` @@ -90,8 +90,6 @@ However, this isn't the most friendly for Python users. Instead, we can use the low level APIs that mirror the rust APIs to build the same egraph: ```{code-cell} python -from egg_smol.bindings_py import * - egraph = EGraph() egraph.declare_sort("Math") egraph.declare_constructor(Variant("Num", ["i64"]), "Math") diff --git a/python/egg_smol/bindings.pyi b/python/egg_smol/bindings.pyi index 3accec5d..e0e2f326 100644 --- a/python/egg_smol/bindings.pyi +++ b/python/egg_smol/bindings.pyi @@ -3,19 +3,147 @@ from typing import Optional from typing_extensions import final -from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite, Variant - @final class EGraph: def parse_and_run_program(self, input: str) -> list[str]: ... def declare_constructor(self, variant: Variant, sort: str) -> None: ... def declare_sort(self, name: str) -> None: ... def declare_function(self, decl: FunctionDecl) -> None: ... - def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ... + def define(self, name: str, expr: _Expr, cost: Optional[int] = None) -> None: ... def add_rewrite(self, rewrite: Rewrite) -> str: ... def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ... - def check_fact(self, fact: Fact_) -> None: ... + def check_fact(self, fact: _Fact) -> None: ... @final class EggSmolError(Exception): context: str + +@final +class Int: + def __init__(self, value: int) -> None: ... + value: int + +@final +class String: + def __init__(self, value: str) -> None: ... + value: str + +@final +class Unit: + def __init__(self) -> None: ... + +_Literal = Int | String | Unit + +@final +class Lit: + def __init__(self, value: _Literal) -> None: ... + value: _Literal + +@final +class Var: + def __init__(self, name: str) -> None: ... + name: str + +@final +class Call: + def __init__(self, name: str, args: list[_Expr]) -> None: ... + name: str + args: list[_Expr] + +_Expr = Lit | Var | Call + +@final +class Eq: + def __init__(self, exprs: list[_Expr]) -> None: ... + exprs: list[_Expr] + +@final +class Fact: + def __init__(self, expr: _Expr) -> None: ... + expr: _Expr + +_Fact = Fact | Eq + +@final +class Define: + def __init__(self, lhs: str, rhs: _Expr) -> None: ... + lhs: str + rhs: _Expr + +@final +class Set: + def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ... + lhs: str + args: list[_Expr] + rhs: _Expr + +@final +class Delete: + sym: str + args: list[_Expr] + def __init__(self, sym: str, args: list[_Expr]) -> None: ... + +@final +class Union: + def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ... + lhs: _Expr + rhs: _Expr + +@final +class Panic: + def __init__(self, msg: str) -> None: ... + msg: str + +@final +class Expr_: + def __init__(self, expr: _Expr) -> None: ... + expr: _Expr + +_Action = Define | Set | Delete | Union | Panic | Expr_ + +@final +class FunctionDecl: + name: str + schema: Schema + default: Optional[_Expr] + merge: Optional[_Expr] + cost: Optional[int] + def __init__( + self, + name: str, + schema: Schema, + default: Optional[_Expr], + merge: Optional[_Expr], + cost: Optional[int] = None, + ) -> None: ... + +@final +class Variant: + def __init__( + self, name: str, types: list[str], cost: Optional[int] = None + ) -> None: ... + name: str + types: list[str] + cost: Optional[int] + +@final +class Schema: + input: list[str] + output: str + def __init__(self, input: list[str], output: str) -> None: ... + +@final +class Rule: + head: list[_Action] + body: list[_Fact] + def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ... + +@final +class Rewrite: + lhs: _Expr + rhs: _Expr + conditions: list[_Fact] + + def __init__( + self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = [] + ) -> None: ... diff --git a/python/egg_smol/bindings_py.py b/python/egg_smol/bindings_py.py deleted file mode 100644 index 0b27ea49..00000000 --- a/python/egg_smol/bindings_py.py +++ /dev/null @@ -1,84 +0,0 @@ -# TODO: Figure out what these modules should be called -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Optional, Union - - -@dataclass(frozen=True) -class Variant: - name: str - types: list[str] - cost: Optional[int] = None - - -@dataclass(frozen=True) -class FunctionDecl: - name: str - schema: Schema - default: Optional[Expr] = None - merge: Optional[Expr] = None - cost: Optional[int] = None - - -@dataclass(frozen=True) -class Schema: - input: list[str] - output: str - - -@dataclass(frozen=True) -class Lit: - value: Literal - - -@dataclass(frozen=True) -class Var: - name: str - - -@dataclass(frozen=True) -class Call: - name: str - args: list[Expr] - - -Expr = Union[Lit, Var, Call] - - -@dataclass(frozen=True) -class Int: - value: int - - -@dataclass(frozen=True) -class String: - value: str - - -@dataclass(frozen=True) -class Unit: - pass - - -Literal = Union[Int, String, Unit] - - -@dataclass(frozen=True) -class Rewrite: - lhs: Expr - rhs: Expr - conditions: list[Fact_] = field(default_factory=list) - - -@dataclass(frozen=True) -class Fact: - expr: Expr - - -@dataclass -class Eq: - exprs: list[Expr] - - -Fact_ = Union[Fact, Eq] diff --git a/python/tests/test.py b/python/tests/test.py index 7fd3ac57..3dfdf90e 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -1,8 +1,7 @@ import datetime import pytest -from egg_smol.bindings import EggSmolError, EGraph -from egg_smol.bindings_py import * +from egg_smol.bindings import * class TestEGraph: @@ -334,3 +333,29 @@ def test_check_fact(self): ] ) ) + + # def test_extract(self): + # # Example from extraction-cost + # egraph = EGraph() + # egraph.declare_sort("Expr") + # egraph.declare_constructor(Variant("Num", ["i64"], cost=5), "Expr") + + # egraph.define("x", Call("Num", [Lit(Int(1))]), cost=10) + # egraph.define("y", Call("Num", [Lit(Int(2))]), cost=1) + + # assert egraph.extract("x") == Call("Num", [Lit(Int(1))]) + # assert egraph.extract("y") == Var("y") + + +class TestVariant: + def test_repr(self): + assert repr(Variant("name", [])) == "Variant('name', [], None)" + + def test_name(self): + assert Variant("name", []).name == "name" + + def test_types(self): + assert Variant("name", ["a", "b"]).types == ["a", "b"] + + def test_cost(self): + assert Variant("name", [], cost=1).cost == 1 diff --git a/src/conversions.rs b/src/conversions.rs index c07018ad..d8cb016a 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -1,251 +1,127 @@ -use std::time::Duration; - // Create wrappers around input types so that convert from pyobjects to them // and then from them to the egg_smol types -// -// Converts from Python classes we define in pure python so we can use dataclasses -// to represent the input types -// TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead. -use pyo3::{ffi::PyDateTime_Delta, prelude::*, types::PyDelta}; - -// Execute the block and wrap the error in a type error -fn wrap_error(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult) -> PyResult { - block().map_err(|e| { - PyErr::new::(format!( - "Error converting {} to {}: {}", - obj, tp, e - )) - }) -} - -// Wrapped version of Variant -pub struct WrappedVariant(egg_smol::ast::Variant); - -impl FromPyObject<'_> for WrappedVariant { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Variant", obj, || { - Ok(WrappedVariant(egg_smol::ast::Variant { - name: obj.getattr("name")?.extract::()?.into(), - cost: obj.getattr("cost")?.extract()?, - types: obj - .getattr("types")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - })) - }) - } -} - -impl From for egg_smol::ast::Variant { - fn from(other: WrappedVariant) -> Self { - other.0 - } -} - -// Wrapped version of FunctionDecl -pub struct WrappedFunctionDecl(egg_smol::ast::FunctionDecl); -impl FromPyObject<'_> for WrappedFunctionDecl { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("FunctionDecl", obj, || { - Ok(WrappedFunctionDecl(egg_smol::ast::FunctionDecl { - name: obj.getattr("name")?.extract::()?.into(), - schema: obj.getattr("schema")?.extract::()?.into(), - default: obj - .getattr("default")? - .extract::>()? - .map(|x| x.into()), - merge: obj - .getattr("merge")? - .extract::>()? - .map(|x| x.into()), - cost: obj.getattr("cost")?.extract()?, - })) - }) - } -} - -impl From for egg_smol::ast::FunctionDecl { - fn from(other: WrappedFunctionDecl) -> Self { - other.0 - } -} - -// Wrapped version of Schema -pub struct WrappedSchema(egg_smol::ast::Schema); - -impl FromPyObject<'_> for WrappedSchema { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Schema", obj, || { - Ok(WrappedSchema(egg_smol::ast::Schema { - input: obj - .getattr("input")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - output: obj.getattr("output")?.extract::()?.into(), - })) - }) - } -} - -impl From for egg_smol::ast::Schema { - fn from(other: WrappedSchema) -> Self { - other.0 - } -} - -// Wrapped version of Expr -pub struct WrappedExpr(egg_smol::ast::Expr); - -impl FromPyObject<'_> for WrappedExpr { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Expr", obj, || - // Try extracting into each type of expression, and return the first one that works - extract_expr_lit(obj) - .or_else(|_| extract_expr_call(obj)) - .or_else(|_| extract_expr_var(obj)) - .map(WrappedExpr)) - } -} - -fn extract_expr_lit(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Expr::Lit( - obj.getattr("value")?.extract::()?.into(), - )) -} - -fn extract_expr_var(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Expr::Var( - obj.getattr("name")?.extract::()?.into(), - )) -} - -fn extract_expr_call(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Expr::Call( - obj.getattr("name")?.extract::()?.into(), - obj.getattr("args")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - )) -} - -impl From for egg_smol::ast::Expr { - fn from(other: WrappedExpr) -> Self { - other.0 - } -} - -// Wrapped version of Literal -pub struct WrappedLiteral(egg_smol::ast::Literal); - -impl FromPyObject<'_> for WrappedLiteral { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Literal", obj, || { - extract_literal_int(obj) - .or_else(|_| extract_literal_string(obj)) - .or_else(|_| extract_literal_unit(obj)) - .map(WrappedLiteral) - }) - } -} - -fn extract_literal_int(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Literal::Int( - obj.getattr("value")?.extract()?, - )) -} - -fn extract_literal_string(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Literal::String( - obj.getattr("value")?.extract::()?.into(), - )) -} -fn extract_literal_unit(obj: &PyAny) -> PyResult { - if obj.is_none() { - Ok(egg_smol::ast::Literal::Unit) - } else { - Err(pyo3::exceptions::PyTypeError::new_err("Expected None")) - } -} - -impl From for egg_smol::ast::Literal { - fn from(other: WrappedLiteral) -> Self { - other.0 - } -} - -// Wrapped version of Rewrite -pub struct WrappedRewrite(egg_smol::ast::Rewrite); - -impl FromPyObject<'_> for WrappedRewrite { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Rewrite", obj, || { - Ok(WrappedRewrite(egg_smol::ast::Rewrite { - lhs: obj.getattr("lhs")?.extract::()?.into(), - rhs: obj.getattr("rhs")?.extract::()?.into(), - conditions: obj - .getattr("conditions")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - })) - }) - } -} - -impl From for egg_smol::ast::Rewrite { - fn from(other: WrappedRewrite) -> Self { - other.0 - } -} - -// Wrapped version of Fact -pub struct WrappedFact(egg_smol::ast::Fact); - -impl FromPyObject<'_> for WrappedFact { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Fact", obj, || { - extract_fact_eq(obj) - .or_else(|_| extract_fact_fact(obj)) - .map(WrappedFact) - }) - } -} - -fn extract_fact_eq(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Fact::Eq( - obj.getattr("exprs")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - )) -} - -fn extract_fact_fact(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Fact::Fact( - obj.getattr("expr")?.extract::()?.into(), - )) -} - -impl From for egg_smol::ast::Fact { - fn from(other: WrappedFact) -> Self { - other.0 - } -} +use crate::utils::*; +use pyo3::prelude::*; + +convert_enums!( + egg_smol::ast::Literal => Literal { + Int(value: i64) + i -> egg_smol::ast::Literal::Int(i.value), + egg_smol::ast::Literal::Int(i) => Int { value: i.clone() }; + String_[name="String"](value: String) + s -> egg_smol::ast::Literal::String((&s.value).into()), + egg_smol::ast::Literal::String(s) => String_ { value: s.to_string() }; + Unit() + _x -> egg_smol::ast::Literal::Unit, + egg_smol::ast::Literal::Unit => Unit {} + }; + egg_smol::ast::Expr => Expr { + Lit(value: Literal) + l -> egg_smol::ast::Expr::Lit((&l.value).into()), + egg_smol::ast::Expr::Lit(l) => Lit { value: l.into() }; + Var(name: String) + v -> egg_smol::ast::Expr::Var((&v.name).into()), + egg_smol::ast::Expr::Var(v) => Var { name: v.to_string() }; + Call(name: String, args: Vec) + c -> egg_smol::ast::Expr::Call((&c.name).into(), (&c.args).into_iter().map(|e| e.into()).collect()), + egg_smol::ast::Expr::Call(c, a) => Call { + name: c.to_string(), + args: a.into_iter().map(|e| e.into()).collect() + } + }; + egg_smol::ast::Fact => Fact_ { + Eq(exprs: Vec) + eq -> egg_smol::ast::Fact::Eq((&eq.exprs).into_iter().map(|e| e.into()).collect()), + egg_smol::ast::Fact::Eq(e) => Eq { exprs: e.into_iter().map(|e| e.into()).collect() }; + Fact(expr: Expr) + f -> egg_smol::ast::Fact::Fact((&f.expr).into()), + egg_smol::ast::Fact::Fact(e) => Fact { expr: e.into() } + }; + egg_smol::ast::Action => Action { + Define(lhs: String, rhs: Expr) + d -> egg_smol::ast::Action::Define((&d.lhs).into(), (&d.rhs).into()), + egg_smol::ast::Action::Define(n, e) => Define { lhs: n.to_string(), rhs: e.into() }; + Set(lhs: String, args: Vec, rhs: Expr) + s -> egg_smol::ast::Action::Set((&s.lhs).into(), (&s.args).into_iter().map(|e| e.into()).collect(), (&s.rhs).into()), + egg_smol::ast::Action::Set(n, a, e) => Set { + lhs: n.to_string(), + args: a.into_iter().map(|e| e.into()).collect(), + rhs: e.into() + }; + Delete(sym: String, args: Vec) + d -> egg_smol::ast::Action::Delete((&d.sym).into(), (&d.args).into_iter().map(|e| e.into()).collect()), + egg_smol::ast::Action::Delete(n, a) => Delete { + sym: n.to_string(), + args: a.into_iter().map(|e| e.into()).collect() + }; + Union(lhs: Expr, rhs: Expr) + u -> egg_smol::ast::Action::Union((&u.lhs).into(), (&u.rhs).into()), + egg_smol::ast::Action::Union(l, r) => Union { lhs: l.into(), rhs: r.into() }; + Panic(msg: String) + p -> egg_smol::ast::Action::Panic(p.msg.to_string()), + egg_smol::ast::Action::Panic(msg) => Panic { msg: msg.to_string() }; + Expr_(expr: Expr) + e -> egg_smol::ast::Action::Expr((&e.expr).into()), + egg_smol::ast::Action::Expr(e) => Expr_ { expr: e.into() } + } +); + +convert_struct!( + egg_smol::ast::FunctionDecl => FunctionDecl( + name: String, + schema: Schema, + default: Option, + merge: Option, + cost: Option = "None" + ) + f -> egg_smol::ast::FunctionDecl { + name: (&f.name).into(), + schema: (&f.schema).into(), + default: f.default.as_ref().map(|e| e.into()), + merge: f.merge.as_ref().map(|e| e.into()), + cost: f.cost + }, + f -> FunctionDecl { + name: f.name.to_string(), + schema: (&f.schema).into(), + default: f.default.as_ref().map(|e| e.into()), + merge: f.merge.as_ref().map(|e| e.into()), + cost: f.cost + }; + egg_smol::ast::Variant => Variant( + name: String, + types: Vec, + cost: Option = "None" + ) + v -> egg_smol::ast::Variant {name: (&v.name).into(), types: (&v.types).into_iter().map(|v| v.into()).collect(), cost: v.cost}, + v -> Variant {name: v.name.to_string(), types: v.types.iter().map(|v| v.to_string()).collect(), cost: v.cost}; + egg_smol::ast::Schema => Schema( + input: Vec, + output: String + ) + s -> egg_smol::ast::Schema {input: (&s.input).into_iter().map(|v| v.into()).collect(), output: (&s.output).into()}, + s -> Schema {input: s.input.iter().map(|v| v.to_string()).collect(), output: s.output.to_string()}; + egg_smol::ast::Rule[display] => Rule( + head: Vec, + body: Vec + ) + r -> egg_smol::ast::Rule {head: (&r.head).into_iter().map(|v| v.into()).collect(), body: (&r.body).into_iter().map(|v| v.into()).collect()}, + r -> Rule {head: r.head.iter().map(|v| v.into()).collect(), body: r.body.iter().map(|v| v.into()).collect()}; + egg_smol::ast::Rewrite => Rewrite( + lhs: Expr, + rhs: Expr, + conditions: Vec = "Vec::new()" + ) + r -> egg_smol::ast::Rewrite {lhs: (&r.lhs).into(), rhs: (&r.rhs).into(), conditions: (&r.conditions).into_iter().map(|v| v.into()).collect()}, + r -> Rewrite {lhs: (&r.lhs).into(), rhs: (&r.rhs).into(), conditions: r.conditions.iter().map(|v| v.into()).collect()} +); // Wrapped version of Duration // Converts from a rust duration to a python timedelta -pub struct WrappedDuration(Duration); +pub struct WrappedDuration(std::time::Duration); -impl From for WrappedDuration { - fn from(other: Duration) -> Self { +impl From for WrappedDuration { + fn from(other: std::time::Duration) -> Self { WrappedDuration(other) } } @@ -253,7 +129,7 @@ impl From for WrappedDuration { impl IntoPy for WrappedDuration { fn into_py(self, py: Python<'_>) -> PyObject { let d = self.0; - PyDelta::new( + pyo3::types::PyDelta::new( py, 0, 0, diff --git a/src/lib.rs b/src/lib.rs index 70d26e22..260e5eb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod conversions; mod error; +mod utils; use conversions::*; use error::*; @@ -23,15 +24,31 @@ impl EGraph { } } - /// Check that a fact is true in the egraph. + /// Extract the best expression of a given value. Will also return + // variants number of additional options. + // #[pyo3(text_signature = "($self, value, variants=0)")] + // fn extract_expr( + // &mut self, + // value: WrappedExpr, + // variants: usize, + // ) -> EggResult<(usize, WrappedExpr, Vec)> { + // let (cost, expr, exprs) = self.egraph.extract_expr(value.into(), variants)?; + // Ok(( + // cost, + // expr.into(), + // exprs.into_iter().map(|x| x.into()).collect(), + // )) + // } + + // /// Check that a fact is true in the egraph. #[pyo3(text_signature = "($self, fact)")] - fn check_fact(&mut self, fact: WrappedFact) -> EggResult<()> { + fn check_fact(&mut self, fact: Fact_) -> EggResult<()> { self.egraph.check_fact(&fact.into())?; Ok({}) } - /// Run the rules on the egraph until it reaches a fixpoint, specifying the max number of iterations. - /// Returns a tuple of the total time spen searching, applying, and rebuilding. + // /// Run the rules on the egraph until it reaches a fixpoint, specifying the max number of iterations. + // /// Returns a tuple of the total time spen searching, applying, and rebuilding. #[pyo3(text_signature = "($self, limit)")] fn run_rules( &mut self, @@ -41,40 +58,40 @@ impl EGraph { Ok((search.into(), apply.into(), rebuild.into())) } - /// Define a rewrite rule, returning the name of the rule + // /// Define a rewrite rule, returning the name of the rule #[pyo3(text_signature = "($self, rewrite)")] - fn add_rewrite(&mut self, rewrite: WrappedRewrite) -> EggResult { + fn add_rewrite(&mut self, rewrite: Rewrite) -> EggResult { let res = self.egraph.add_rewrite(rewrite.into())?; Ok(res.to_string()) } - /// Define a new named value. + // /// Define a new named value. #[pyo3( text_signature = "($self, name, expr, cost=None)", signature = "(name, expr, cost=None)" )] - fn define(&mut self, name: String, expr: WrappedExpr, cost: Option) -> EggResult<()> { + fn define(&mut self, name: String, expr: Expr, cost: Option) -> EggResult<()> { self.egraph.define(name.into(), expr.into(), cost)?; Ok(()) } - /// Declare a new function definition. + // /// Declare a new function definition. #[pyo3(text_signature = "($self, decl)")] - fn declare_function(&mut self, decl: WrappedFunctionDecl) -> EggResult<()> { + fn declare_function(&mut self, decl: FunctionDecl) -> EggResult<()> { self.egraph.declare_function(&decl.into())?; Ok(()) } - /// Declare a new sort with the given name. + // /// Declare a new sort with the given name. #[pyo3(text_signature = "($self, name)")] fn declare_sort(&mut self, name: &str) -> EggResult<()> { self.egraph.declare_sort(name)?; Ok({}) } - /// Declare a new datatype constructor. + // /// Declare a new datatype constructor. #[pyo3(text_signature = "($self, variant, sort)")] - fn declare_constructor(&mut self, variant: WrappedVariant, sort: &str) -> EggResult<()> { + fn declare_constructor(&mut self, variant: Variant, sort: &str) -> EggResult<()> { self.egraph.declare_constructor(variant.into(), sort)?; Ok({}) } @@ -94,5 +111,9 @@ impl EGraph { fn bindings(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + + add_structs_to_module(m)?; + add_enums_to_module(m)?; + Ok(()) } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 00000000..d9a18d37 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,221 @@ +use pyo3::prelude::*; + +pub fn display(t: &T) -> String +where + T: Clone, + V: std::fmt::Display + core::convert::From, +{ + format!("{:}", V::from(t.clone())) +} + +// Create a dataclass-like repr, of the name of the class of the object +// called with the repr of the fields +pub fn data_repr( + py: Python, + slf: PyRef, + field_names: Vec<&str>, +) -> PyResult { + let obj = slf.into_py(py); + let class_name: String = obj + .getattr(py, "__class__")? + .getattr(py, "__name__")? + .extract(py)?; + let field_strings: PyResult> = field_names + .iter() + .map(|name| { + obj.getattr(py, *name) + .and_then(|x| x.call_method(py, "__repr__", (), None)?.extract(py)) + }) + .collect(); + Ok(format!("{}({})", class_name, field_strings?.join(", "))) +} + +// Macro to create a wrapper around rust enums. +// We create Python classes for each variant of the enum +// and create a wrapper enum around all variants to enable conversion to/from Python +// and to/from egg_smol +#[macro_export] +macro_rules! convert_enums { + ($( + $from_type:ty => $to_type:ident { + $( + $variant:ident$([name=$py_name:literal])?($($field:ident: $field_type:ty),*) + $from_ident:ident -> $from:expr, + $to_pat:pat => $to:expr + );* + } + );*) => { + $($( + #[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)] + #[derive(Clone)] + pub struct $variant { + $( + #[pyo3(get)] + $field: $field_type, + )* + } + + #[pymethods] + impl $variant { + #[new] + fn new($($field: $field_type),*) -> Self { + Self { + $($field),* + } + } + + fn __repr__(slf: PyRef<'_, Self>, py: Python) -> PyResult { + data_repr(py, slf, vec![$(stringify!($field)),*]) + } + + fn __str__(&self) -> String { + display::<_, $from_type>(self) + } + } + + impl From<$variant> for $from_type { + fn from($from_ident: $variant) -> $from_type { + $from + } + } + impl From<&$variant> for $from_type { + fn from($from_ident: &$variant) -> $from_type { + $from + } + } + )* + + #[derive(FromPyObject, Clone)] + pub enum $to_type { + $( + $variant($variant), + )* + } + impl IntoPy for $to_type { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + $( + $to_type::$variant(v) => v.into_py(py), + )* + } + } + } + impl From<$to_type> for $from_type { + fn from(other: $to_type) -> Self { + match other { + $( + $to_type::$variant(v) => v.into(), + )* + } + } + } + + impl From<$from_type> for $to_type { + fn from(other: $from_type) -> Self { + match other { + $( + $to_pat => $to_type::$variant($to), + )* + } + } + } + + impl From<&$to_type> for $from_type { + fn from(other: &$to_type) -> Self { + match other { + $( + $to_type::$variant(v) => v.into(), + )* + } + } + } + + impl From<&$from_type> for $to_type { + fn from(other: &$from_type) -> Self { + match other { + $( + $to_pat => $to_type::$variant($to), + )* + } + } + } + )* + pub fn add_enums_to_module(module: &PyModule) -> PyResult<()> { + $( + $( + module.add_class::<$variant>()?; + )* + )* + Ok(()) + } + }; +} + +#[macro_export] +macro_rules! convert_struct { + ($( + $from_type:ty$([$str_fn:ident])? => $to_type:ident($($field:ident: $field_type:ty$( = $default:literal)?),*) + $from_ident:ident -> $from:expr, + $to_ident:ident -> $to:expr + );*) => { + $( + #[pyclass(frozen, module="egg_smol.bindings")] + #[derive(Clone)] + pub struct $to_type { + $( + #[pyo3(get)] + $field: $field_type, + )* + } + + #[pymethods] + impl $to_type { + #[new] + #[args($($($field = $default)?)*)] + fn new($($field: $field_type),*) -> Self { + Self { + $($field),* + } + } + + fn __repr__(slf: PyRef<'_, Self>, py: Python) -> PyResult { + data_repr(py, slf, vec![$(stringify!($field)),*]) + } + $( + fn __str__(&self) -> String { + $str_fn::<_, $from_type>(self) + } + )? + } + + impl From<&$to_type> for $from_type { + fn from($from_ident: &$to_type) -> $from_type { + $from + } + } + impl From<&$from_type> for $to_type { + fn from($to_ident: &$from_type) -> Self { + $to + } + } + impl From<$to_type> for $from_type { + fn from($from_ident: $to_type) -> $from_type { + $from + } + } + impl From<$from_type> for $to_type { + fn from($to_ident: $from_type) -> Self { + $to + } + } + )* + pub fn add_structs_to_module(module: &PyModule) -> PyResult<()> { + $( + module.add_class::<$to_type>()?; + )* + Ok(()) + } + }; +} +pub use convert_enums; +pub use convert_struct; diff --git a/stubtest_allow b/stubtest_allow new file mode 100644 index 00000000..275dd646 --- /dev/null +++ b/stubtest_allow @@ -0,0 +1 @@ +.*egg_smol.bindings.Unit.__init__.*