From 1a8491038e9f9cb333e463690f6f84020f2a6900 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sat, 12 Nov 2022 15:42:11 -0500 Subject: [PATCH 1/7] move variant into rust --- python/egg_smol/bindings.pyi | 11 ++++- python/egg_smol/bindings_py.py | 7 ---- python/tests/test.py | 14 ++++++- src/conversions.rs | 77 ++++++++++++++++++++++++++-------- src/lib.rs | 21 ++++++++++ 5 files changed, 104 insertions(+), 26 deletions(-) diff --git a/python/egg_smol/bindings.pyi b/python/egg_smol/bindings.pyi index 3accec5d..ee7bb514 100644 --- a/python/egg_smol/bindings.pyi +++ b/python/egg_smol/bindings.pyi @@ -3,7 +3,16 @@ from typing import Optional from typing_extensions import final -from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite, Variant +from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite + +@final +class Variant: + def __init__( + self, name: str, types: list[str], cost: Optional[int] = None + ) -> None: ... + name: str + types: list[str] + cost: Optional[int] @final class EGraph: diff --git a/python/egg_smol/bindings_py.py b/python/egg_smol/bindings_py.py index 0b27ea49..5bf29997 100644 --- a/python/egg_smol/bindings_py.py +++ b/python/egg_smol/bindings_py.py @@ -5,13 +5,6 @@ from typing import Optional, Union -@dataclass(frozen=True) -class Variant: - name: str - types: list[str] - cost: Optional[int] = None - - @dataclass(frozen=True) class FunctionDecl: name: str diff --git a/python/tests/test.py b/python/tests/test.py index 7fd3ac57..b37617e9 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -1,7 +1,7 @@ import datetime import pytest -from egg_smol.bindings import EggSmolError, EGraph +from egg_smol.bindings import * from egg_smol.bindings_py import * @@ -334,3 +334,15 @@ def test_check_fact(self): ] ) ) + + # def test_extract(self): + # # Example from extraction-cost + # egraph = EGraph() + # egraph.declare_sort("Expr") + # egraph.declare_constructor(Variant("Num", ["i64"], cost=5), "Expr") + + # egraph.define("x", Call("Num", [Lit(Int(1))]), cost=10) + # egraph.define("y", Call("Num", [Lit(Int(2))]), cost=1) + + # assert egraph.extract("x") == Call("Num", [Lit(Int(1))]) + # assert egraph.extract("y") == Var("y") diff --git a/src/conversions.rs b/src/conversions.rs index c07018ad..809b382a 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -6,7 +6,7 @@ use std::time::Duration; // Converts from Python classes we define in pure python so we can use dataclasses // to represent the input types // TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead. -use pyo3::{ffi::PyDateTime_Delta, prelude::*, types::PyDelta}; +use pyo3::prelude::*; // Execute the block and wrap the error in a type error fn wrap_error(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult) -> PyResult { @@ -18,24 +18,55 @@ fn wrap_error(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult) }) } -// Wrapped version of Variant -pub struct WrappedVariant(egg_smol::ast::Variant); +// A variant of a constructor -impl FromPyObject<'_> for WrappedVariant { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Variant", obj, || { - Ok(WrappedVariant(egg_smol::ast::Variant { - name: obj.getattr("name")?.extract::()?.into(), - cost: obj.getattr("cost")?.extract()?, - types: obj - .getattr("types")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - })) +#[pyclass(name = "Variant")] +#[derive(Clone)] +pub(crate) struct WrappedVariant(egg_smol::ast::Variant); + +#[pymethods] +impl WrappedVariant { + #[new] + fn new(name: String, types: Vec, cost: Option) -> Self { + Self(egg_smol::ast::Variant { + name: name.into(), + types: types.into_iter().map(|x| x.into()).collect(), + cost, }) } + #[getter] + fn name(&self) -> &str { + self.0.name.into() + } + #[getter] + fn types(&self) -> Vec { + self.0.types.iter().map(|x| x.to_string()).collect() + } + #[getter] + fn cost(&self) -> Option { + self.0.cost + } + + fn __repr__(&self) -> String { + format!( + "Variant(name={}, types=[{}], cost={})", + self.0.name.to_string(), + self.0 + .types + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "), + match self.0.cost { + Some(x) => x.to_string(), + None => "None".to_string(), + } + ) + } + + fn __str__(&self) -> String { + format!("{:#?}", self.0) + } } impl From for egg_smol::ast::Variant { @@ -44,6 +75,12 @@ impl From for egg_smol::ast::Variant { } } +impl From for WrappedVariant { + fn from(other: egg_smol::ast::Variant) -> Self { + WrappedVariant(other) + } +} + // Wrapped version of FunctionDecl pub struct WrappedFunctionDecl(egg_smol::ast::FunctionDecl); impl FromPyObject<'_> for WrappedFunctionDecl { @@ -140,6 +177,12 @@ impl From for egg_smol::ast::Expr { } } +impl From for WrappedExpr { + fn from(other: egg_smol::ast::Expr) -> Self { + WrappedExpr(other) + } +} + // Wrapped version of Literal pub struct WrappedLiteral(egg_smol::ast::Literal); @@ -253,7 +296,7 @@ impl From for WrappedDuration { impl IntoPy for WrappedDuration { fn into_py(self, py: Python<'_>) -> PyObject { let d = self.0; - PyDelta::new( + pyo3::types::PyDelta::new( py, 0, 0, diff --git a/src/lib.rs b/src/lib.rs index 70d26e22..0cc3c3d6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,26 @@ impl EGraph { } } + // fn __repr__(&mut self) -> PyResult { + // Ok(format!("{:#?}", self.egraph)) + // } + + /// Extract the best expression of a given value. Will also return + // variants number of additional options. + // #[pyo3(text_signature = "($self, value, variants=0)")] + // fn extract_expr( + // &mut self, + // value: WrappedExpr, + // variants: usize, + // ) -> EggResult<(usize, WrappedExpr, Vec)> { + // let (cost, expr, exprs) = self.egraph.extract_expr(value.into(), variants)?; + // Ok(( + // cost, + // expr.into(), + // exprs.into_iter().map(|x| x.into()).collect(), + // )) + // } + /// Check that a fact is true in the egraph. #[pyo3(text_signature = "($self, fact)")] fn check_fact(&mut self, fact: WrappedFact) -> EggResult<()> { @@ -94,5 +114,6 @@ impl EGraph { fn bindings(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } From d89435559b025c59afc86427c2524c2b60cb9abf Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 13 Nov 2022 19:28:48 -0500 Subject: [PATCH 2/7] Fix docs --- docs/explanation/compared_to_rust.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/explanation/compared_to_rust.md b/docs/explanation/compared_to_rust.md index d7f2b54f..4598044c 100644 --- a/docs/explanation/compared_to_rust.md +++ b/docs/explanation/compared_to_rust.md @@ -52,6 +52,9 @@ One way to run this in Python is to parse the text and run it similar to how the egg CLI works: ```{code-cell} python +from egg_smol.bindings import * +from egg_smol.bindings_py import * + eqsat_basic = """(datatype Math (Num i64) (Var String) @@ -78,8 +81,6 @@ eqsat_basic = """(datatype Math (run 10) (check (= expr1 expr2))""" -from egg_smol.bindings import EGraph - egraph = EGraph() egraph.parse_and_run_program(eqsat_basic) ``` @@ -90,8 +91,6 @@ However, this isn't the most friendly for Python users. Instead, we can use the low level APIs that mirror the rust APIs to build the same egraph: ```{code-cell} python -from egg_smol.bindings_py import * - egraph = EGraph() egraph.declare_sort("Math") egraph.declare_constructor(Variant("Num", ["i64"]), "Math") From 2e645e7396895ed8128b6734006ae1cf3b2c0ecf Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 13 Nov 2022 22:47:07 -0500 Subject: [PATCH 3/7] fix string --- .pre-commit-config.yaml | 4 ---- python/tests/test.py | 14 ++++++++++++++ src/conversions.rs | 38 ++++++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b7c2b0af..97e17199 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,3 @@ repos: rev: 5.0.4 hooks: - id: flake8 - - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v0.982" - hooks: - - id: mypy diff --git a/python/tests/test.py b/python/tests/test.py index b37617e9..d71c6831 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -346,3 +346,17 @@ def test_check_fact(self): # assert egraph.extract("x") == Call("Num", [Lit(Int(1))]) # assert egraph.extract("y") == Var("y") + + +class TestVariant: + def test_repr(self): + assert repr(Variant("name", [])) == "Variant('name', [], None)" + + def test_name(self): + assert Variant("name", []).name == "name" + + def test_types(self): + assert Variant("name", ["a", "b"]).types == ["a", "b"] + + def test_cost(self): + assert Variant("name", [], cost=1).cost == 1 diff --git a/src/conversions.rs b/src/conversions.rs index 809b382a..f67fb764 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -18,7 +18,26 @@ fn wrap_error(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult) }) } -// A variant of a constructor +// Take the repr of a Python object +fn repr(py: Python, obj: PyObject) -> PyResult { + obj.call_method(py, "__repr__", (), None)?.extract(py) +} + +trait SizedIntoPy: IntoPy + Sized {} + +// Create a dataclass-like repr, of the name of the class of the object +// called with the repr of the fields +fn data_repr(py: Python, obj: PyObject, field_names: Vec<&str>) -> PyResult { + let class_name: String = obj + .getattr(py, "__class__")? + .getattr(py, "__name__")? + .extract(py)?; + let field_strings: PyResult> = field_names + .iter() + .map(|name| obj.getattr(py, *name).and_then(|x| repr(py, x))) + .collect(); + Ok(format!("{}({})", class_name, field_strings?.join(", "))) +} #[pyclass(name = "Variant")] #[derive(Clone)] @@ -47,21 +66,8 @@ impl WrappedVariant { self.0.cost } - fn __repr__(&self) -> String { - format!( - "Variant(name={}, types=[{}], cost={})", - self.0.name.to_string(), - self.0 - .types - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "), - match self.0.cost { - Some(x) => x.to_string(), - None => "None".to_string(), - } - ) + fn __repr__(slf: PyRef<'_, Self>, py: Python) -> PyResult { + data_repr(py, slf.into_py(py), vec!["name", "types", "cost"]) } fn __str__(&self) -> String { From 4365f3d25ada53b7fd50552deaf4304747fa41d9 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 13 Nov 2022 23:07:49 -0500 Subject: [PATCH 4/7] Minimize size --- src/conversions.rs | 20 +++----------------- src/lib.rs | 6 +++--- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/src/conversions.rs b/src/conversions.rs index f67fb764..c69627fb 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -23,8 +23,6 @@ fn repr(py: Python, obj: PyObject) -> PyResult { obj.call_method(py, "__repr__", (), None)?.extract(py) } -trait SizedIntoPy: IntoPy + Sized {} - // Create a dataclass-like repr, of the name of the class of the object // called with the repr of the fields fn data_repr(py: Python, obj: PyObject, field_names: Vec<&str>) -> PyResult { @@ -39,12 +37,12 @@ fn data_repr(py: Python, obj: PyObject, field_names: Vec<&str>) -> PyResult, cost: Option) -> Self { Self(egg_smol::ast::Variant { @@ -75,18 +73,6 @@ impl WrappedVariant { } } -impl From for egg_smol::ast::Variant { - fn from(other: WrappedVariant) -> Self { - other.0 - } -} - -impl From for WrappedVariant { - fn from(other: egg_smol::ast::Variant) -> Self { - WrappedVariant(other) - } -} - // Wrapped version of FunctionDecl pub struct WrappedFunctionDecl(egg_smol::ast::FunctionDecl); impl FromPyObject<'_> for WrappedFunctionDecl { diff --git a/src/lib.rs b/src/lib.rs index 0cc3c3d6..3cbb768c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,8 +94,8 @@ impl EGraph { /// Declare a new datatype constructor. #[pyo3(text_signature = "($self, variant, sort)")] - fn declare_constructor(&mut self, variant: WrappedVariant, sort: &str) -> EggResult<()> { - self.egraph.declare_constructor(variant.into(), sort)?; + fn declare_constructor(&mut self, variant: Variant, sort: &str) -> EggResult<()> { + self.egraph.declare_constructor(variant.0, sort)?; Ok({}) } @@ -114,6 +114,6 @@ impl EGraph { fn bindings(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; Ok(()) } From bca8b65081d10530744ff1e650133552d831fc10 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 17 Nov 2022 18:53:23 -0500 Subject: [PATCH 5/7] Try switching entirely to rust --- .flake8 | 2 +- .github/workflows/CI.yml | 2 +- python/egg_smol/bindings.pyi | 143 +++++++++++- python/egg_smol/bindings_py.py | 77 ------- python/tests/test.py | 1 - src/conversions.rs | 391 ++++++++++----------------------- src/lib.rs | 36 +-- src/utils.rs | 221 +++++++++++++++++++ stubtest_allow | 1 + 9 files changed, 489 insertions(+), 385 deletions(-) delete mode 100644 python/egg_smol/bindings_py.py create mode 100644 src/utils.rs create mode 100644 stubtest_allow diff --git a/.flake8 b/.flake8 index 8de4fc94..6addb2e2 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ ; https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8 [flake8] max-line-length = 88 -extend-ignore = E203,E501,F405,F403,E302 \ No newline at end of file +extend-ignore = E203,E501,F405,F403,E302,E305,F821 \ No newline at end of file diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 63c81f5e..b77c918e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -38,7 +38,7 @@ jobs: python-version: "3.10" - uses: actions/checkout@v2 - run: pip install -e . mypy - - run: python -m mypy.stubtest egg_smol.bindings + - run: python -m mypy.stubtest egg_smol.bindings --allowlist stubtest_allow docs: runs-on: ubuntu-latest steps: diff --git a/python/egg_smol/bindings.pyi b/python/egg_smol/bindings.pyi index ee7bb514..e0e2f326 100644 --- a/python/egg_smol/bindings.pyi +++ b/python/egg_smol/bindings.pyi @@ -3,7 +3,119 @@ from typing import Optional from typing_extensions import final -from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite +@final +class EGraph: + def parse_and_run_program(self, input: str) -> list[str]: ... + def declare_constructor(self, variant: Variant, sort: str) -> None: ... + def declare_sort(self, name: str) -> None: ... + def declare_function(self, decl: FunctionDecl) -> None: ... + def define(self, name: str, expr: _Expr, cost: Optional[int] = None) -> None: ... + def add_rewrite(self, rewrite: Rewrite) -> str: ... + def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ... + def check_fact(self, fact: _Fact) -> None: ... + +@final +class EggSmolError(Exception): + context: str + +@final +class Int: + def __init__(self, value: int) -> None: ... + value: int + +@final +class String: + def __init__(self, value: str) -> None: ... + value: str + +@final +class Unit: + def __init__(self) -> None: ... + +_Literal = Int | String | Unit + +@final +class Lit: + def __init__(self, value: _Literal) -> None: ... + value: _Literal + +@final +class Var: + def __init__(self, name: str) -> None: ... + name: str + +@final +class Call: + def __init__(self, name: str, args: list[_Expr]) -> None: ... + name: str + args: list[_Expr] + +_Expr = Lit | Var | Call + +@final +class Eq: + def __init__(self, exprs: list[_Expr]) -> None: ... + exprs: list[_Expr] + +@final +class Fact: + def __init__(self, expr: _Expr) -> None: ... + expr: _Expr + +_Fact = Fact | Eq + +@final +class Define: + def __init__(self, lhs: str, rhs: _Expr) -> None: ... + lhs: str + rhs: _Expr + +@final +class Set: + def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ... + lhs: str + args: list[_Expr] + rhs: _Expr + +@final +class Delete: + sym: str + args: list[_Expr] + def __init__(self, sym: str, args: list[_Expr]) -> None: ... + +@final +class Union: + def __init__(self, lhs: _Expr, rhs: _Expr) -> None: ... + lhs: _Expr + rhs: _Expr + +@final +class Panic: + def __init__(self, msg: str) -> None: ... + msg: str + +@final +class Expr_: + def __init__(self, expr: _Expr) -> None: ... + expr: _Expr + +_Action = Define | Set | Delete | Union | Panic | Expr_ + +@final +class FunctionDecl: + name: str + schema: Schema + default: Optional[_Expr] + merge: Optional[_Expr] + cost: Optional[int] + def __init__( + self, + name: str, + schema: Schema, + default: Optional[_Expr], + merge: Optional[_Expr], + cost: Optional[int] = None, + ) -> None: ... @final class Variant: @@ -15,16 +127,23 @@ class Variant: cost: Optional[int] @final -class EGraph: - def parse_and_run_program(self, input: str) -> list[str]: ... - def declare_constructor(self, variant: Variant, sort: str) -> None: ... - def declare_sort(self, name: str) -> None: ... - def declare_function(self, decl: FunctionDecl) -> None: ... - def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ... - def add_rewrite(self, rewrite: Rewrite) -> str: ... - def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ... - def check_fact(self, fact: Fact_) -> None: ... +class Schema: + input: list[str] + output: str + def __init__(self, input: list[str], output: str) -> None: ... @final -class EggSmolError(Exception): - context: str +class Rule: + head: list[_Action] + body: list[_Fact] + def __init__(self, head: list[_Action], body: list[_Fact]) -> None: ... + +@final +class Rewrite: + lhs: _Expr + rhs: _Expr + conditions: list[_Fact] + + def __init__( + self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = [] + ) -> None: ... diff --git a/python/egg_smol/bindings_py.py b/python/egg_smol/bindings_py.py deleted file mode 100644 index 5bf29997..00000000 --- a/python/egg_smol/bindings_py.py +++ /dev/null @@ -1,77 +0,0 @@ -# TODO: Figure out what these modules should be called -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Optional, Union - - -@dataclass(frozen=True) -class FunctionDecl: - name: str - schema: Schema - default: Optional[Expr] = None - merge: Optional[Expr] = None - cost: Optional[int] = None - - -@dataclass(frozen=True) -class Schema: - input: list[str] - output: str - - -@dataclass(frozen=True) -class Lit: - value: Literal - - -@dataclass(frozen=True) -class Var: - name: str - - -@dataclass(frozen=True) -class Call: - name: str - args: list[Expr] - - -Expr = Union[Lit, Var, Call] - - -@dataclass(frozen=True) -class Int: - value: int - - -@dataclass(frozen=True) -class String: - value: str - - -@dataclass(frozen=True) -class Unit: - pass - - -Literal = Union[Int, String, Unit] - - -@dataclass(frozen=True) -class Rewrite: - lhs: Expr - rhs: Expr - conditions: list[Fact_] = field(default_factory=list) - - -@dataclass(frozen=True) -class Fact: - expr: Expr - - -@dataclass -class Eq: - exprs: list[Expr] - - -Fact_ = Union[Fact, Eq] diff --git a/python/tests/test.py b/python/tests/test.py index d71c6831..3dfdf90e 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -2,7 +2,6 @@ import pytest from egg_smol.bindings import * -from egg_smol.bindings_py import * class TestEGraph: diff --git a/src/conversions.rs b/src/conversions.rs index c69627fb..d8cb016a 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -1,286 +1,127 @@ -use std::time::Duration; - // Create wrappers around input types so that convert from pyobjects to them // and then from them to the egg_smol types -// -// Converts from Python classes we define in pure python so we can use dataclasses -// to represent the input types -// TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead. +use crate::utils::*; use pyo3::prelude::*; -// Execute the block and wrap the error in a type error -fn wrap_error(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult) -> PyResult { - block().map_err(|e| { - PyErr::new::(format!( - "Error converting {} to {}: {}", - obj, tp, e - )) - }) -} - -// Take the repr of a Python object -fn repr(py: Python, obj: PyObject) -> PyResult { - obj.call_method(py, "__repr__", (), None)?.extract(py) -} - -// Create a dataclass-like repr, of the name of the class of the object -// called with the repr of the fields -fn data_repr(py: Python, obj: PyObject, field_names: Vec<&str>) -> PyResult { - let class_name: String = obj - .getattr(py, "__class__")? - .getattr(py, "__name__")? - .extract(py)?; - let field_strings: PyResult> = field_names - .iter() - .map(|name| obj.getattr(py, *name).and_then(|x| repr(py, x))) - .collect(); - Ok(format!("{}({})", class_name, field_strings?.join(", "))) -} - -#[pyclass] -#[derive(Clone)] -pub struct Variant(pub egg_smol::ast::Variant); - -#[pymethods] -impl Variant { - #[new] - fn new(name: String, types: Vec, cost: Option) -> Self { - Self(egg_smol::ast::Variant { - name: name.into(), - types: types.into_iter().map(|x| x.into()).collect(), - cost, - }) - } - #[getter] - fn name(&self) -> &str { - self.0.name.into() - } - #[getter] - fn types(&self) -> Vec { - self.0.types.iter().map(|x| x.to_string()).collect() - } - #[getter] - fn cost(&self) -> Option { - self.0.cost - } - - fn __repr__(slf: PyRef<'_, Self>, py: Python) -> PyResult { - data_repr(py, slf.into_py(py), vec!["name", "types", "cost"]) - } - - fn __str__(&self) -> String { - format!("{:#?}", self.0) - } -} - -// Wrapped version of FunctionDecl -pub struct WrappedFunctionDecl(egg_smol::ast::FunctionDecl); -impl FromPyObject<'_> for WrappedFunctionDecl { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("FunctionDecl", obj, || { - Ok(WrappedFunctionDecl(egg_smol::ast::FunctionDecl { - name: obj.getattr("name")?.extract::()?.into(), - schema: obj.getattr("schema")?.extract::()?.into(), - default: obj - .getattr("default")? - .extract::>()? - .map(|x| x.into()), - merge: obj - .getattr("merge")? - .extract::>()? - .map(|x| x.into()), - cost: obj.getattr("cost")?.extract()?, - })) - }) - } -} - -impl From for egg_smol::ast::FunctionDecl { - fn from(other: WrappedFunctionDecl) -> Self { - other.0 - } -} - -// Wrapped version of Schema -pub struct WrappedSchema(egg_smol::ast::Schema); - -impl FromPyObject<'_> for WrappedSchema { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Schema", obj, || { - Ok(WrappedSchema(egg_smol::ast::Schema { - input: obj - .getattr("input")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - output: obj.getattr("output")?.extract::()?.into(), - })) - }) - } -} - -impl From for egg_smol::ast::Schema { - fn from(other: WrappedSchema) -> Self { - other.0 - } -} - -// Wrapped version of Expr -pub struct WrappedExpr(egg_smol::ast::Expr); - -impl FromPyObject<'_> for WrappedExpr { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Expr", obj, || - // Try extracting into each type of expression, and return the first one that works - extract_expr_lit(obj) - .or_else(|_| extract_expr_call(obj)) - .or_else(|_| extract_expr_var(obj)) - .map(WrappedExpr)) - } -} - -fn extract_expr_lit(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Expr::Lit( - obj.getattr("value")?.extract::()?.into(), - )) -} - -fn extract_expr_var(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Expr::Var( - obj.getattr("name")?.extract::()?.into(), - )) -} - -fn extract_expr_call(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Expr::Call( - obj.getattr("name")?.extract::()?.into(), - obj.getattr("args")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - )) -} - -impl From for egg_smol::ast::Expr { - fn from(other: WrappedExpr) -> Self { - other.0 - } -} - -impl From for WrappedExpr { - fn from(other: egg_smol::ast::Expr) -> Self { - WrappedExpr(other) - } -} - -// Wrapped version of Literal -pub struct WrappedLiteral(egg_smol::ast::Literal); - -impl FromPyObject<'_> for WrappedLiteral { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Literal", obj, || { - extract_literal_int(obj) - .or_else(|_| extract_literal_string(obj)) - .or_else(|_| extract_literal_unit(obj)) - .map(WrappedLiteral) - }) - } -} - -fn extract_literal_int(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Literal::Int( - obj.getattr("value")?.extract()?, - )) -} - -fn extract_literal_string(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Literal::String( - obj.getattr("value")?.extract::()?.into(), - )) -} -fn extract_literal_unit(obj: &PyAny) -> PyResult { - if obj.is_none() { - Ok(egg_smol::ast::Literal::Unit) - } else { - Err(pyo3::exceptions::PyTypeError::new_err("Expected None")) - } -} - -impl From for egg_smol::ast::Literal { - fn from(other: WrappedLiteral) -> Self { - other.0 - } -} - -// Wrapped version of Rewrite -pub struct WrappedRewrite(egg_smol::ast::Rewrite); - -impl FromPyObject<'_> for WrappedRewrite { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Rewrite", obj, || { - Ok(WrappedRewrite(egg_smol::ast::Rewrite { - lhs: obj.getattr("lhs")?.extract::()?.into(), - rhs: obj.getattr("rhs")?.extract::()?.into(), - conditions: obj - .getattr("conditions")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - })) - }) - } -} - -impl From for egg_smol::ast::Rewrite { - fn from(other: WrappedRewrite) -> Self { - other.0 - } -} - -// Wrapped version of Fact -pub struct WrappedFact(egg_smol::ast::Fact); - -impl FromPyObject<'_> for WrappedFact { - fn extract(obj: &'_ PyAny) -> PyResult { - wrap_error("Fact", obj, || { - extract_fact_eq(obj) - .or_else(|_| extract_fact_fact(obj)) - .map(WrappedFact) - }) - } -} - -fn extract_fact_eq(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Fact::Eq( - obj.getattr("exprs")? - .extract::>()? - .into_iter() - .map(|x| x.into()) - .collect(), - )) -} - -fn extract_fact_fact(obj: &PyAny) -> PyResult { - Ok(egg_smol::ast::Fact::Fact( - obj.getattr("expr")?.extract::()?.into(), - )) -} - -impl From for egg_smol::ast::Fact { - fn from(other: WrappedFact) -> Self { - other.0 - } -} +convert_enums!( + egg_smol::ast::Literal => Literal { + Int(value: i64) + i -> egg_smol::ast::Literal::Int(i.value), + egg_smol::ast::Literal::Int(i) => Int { value: i.clone() }; + String_[name="String"](value: String) + s -> egg_smol::ast::Literal::String((&s.value).into()), + egg_smol::ast::Literal::String(s) => String_ { value: s.to_string() }; + Unit() + _x -> egg_smol::ast::Literal::Unit, + egg_smol::ast::Literal::Unit => Unit {} + }; + egg_smol::ast::Expr => Expr { + Lit(value: Literal) + l -> egg_smol::ast::Expr::Lit((&l.value).into()), + egg_smol::ast::Expr::Lit(l) => Lit { value: l.into() }; + Var(name: String) + v -> egg_smol::ast::Expr::Var((&v.name).into()), + egg_smol::ast::Expr::Var(v) => Var { name: v.to_string() }; + Call(name: String, args: Vec) + c -> egg_smol::ast::Expr::Call((&c.name).into(), (&c.args).into_iter().map(|e| e.into()).collect()), + egg_smol::ast::Expr::Call(c, a) => Call { + name: c.to_string(), + args: a.into_iter().map(|e| e.into()).collect() + } + }; + egg_smol::ast::Fact => Fact_ { + Eq(exprs: Vec) + eq -> egg_smol::ast::Fact::Eq((&eq.exprs).into_iter().map(|e| e.into()).collect()), + egg_smol::ast::Fact::Eq(e) => Eq { exprs: e.into_iter().map(|e| e.into()).collect() }; + Fact(expr: Expr) + f -> egg_smol::ast::Fact::Fact((&f.expr).into()), + egg_smol::ast::Fact::Fact(e) => Fact { expr: e.into() } + }; + egg_smol::ast::Action => Action { + Define(lhs: String, rhs: Expr) + d -> egg_smol::ast::Action::Define((&d.lhs).into(), (&d.rhs).into()), + egg_smol::ast::Action::Define(n, e) => Define { lhs: n.to_string(), rhs: e.into() }; + Set(lhs: String, args: Vec, rhs: Expr) + s -> egg_smol::ast::Action::Set((&s.lhs).into(), (&s.args).into_iter().map(|e| e.into()).collect(), (&s.rhs).into()), + egg_smol::ast::Action::Set(n, a, e) => Set { + lhs: n.to_string(), + args: a.into_iter().map(|e| e.into()).collect(), + rhs: e.into() + }; + Delete(sym: String, args: Vec) + d -> egg_smol::ast::Action::Delete((&d.sym).into(), (&d.args).into_iter().map(|e| e.into()).collect()), + egg_smol::ast::Action::Delete(n, a) => Delete { + sym: n.to_string(), + args: a.into_iter().map(|e| e.into()).collect() + }; + Union(lhs: Expr, rhs: Expr) + u -> egg_smol::ast::Action::Union((&u.lhs).into(), (&u.rhs).into()), + egg_smol::ast::Action::Union(l, r) => Union { lhs: l.into(), rhs: r.into() }; + Panic(msg: String) + p -> egg_smol::ast::Action::Panic(p.msg.to_string()), + egg_smol::ast::Action::Panic(msg) => Panic { msg: msg.to_string() }; + Expr_(expr: Expr) + e -> egg_smol::ast::Action::Expr((&e.expr).into()), + egg_smol::ast::Action::Expr(e) => Expr_ { expr: e.into() } + } +); + +convert_struct!( + egg_smol::ast::FunctionDecl => FunctionDecl( + name: String, + schema: Schema, + default: Option, + merge: Option, + cost: Option = "None" + ) + f -> egg_smol::ast::FunctionDecl { + name: (&f.name).into(), + schema: (&f.schema).into(), + default: f.default.as_ref().map(|e| e.into()), + merge: f.merge.as_ref().map(|e| e.into()), + cost: f.cost + }, + f -> FunctionDecl { + name: f.name.to_string(), + schema: (&f.schema).into(), + default: f.default.as_ref().map(|e| e.into()), + merge: f.merge.as_ref().map(|e| e.into()), + cost: f.cost + }; + egg_smol::ast::Variant => Variant( + name: String, + types: Vec, + cost: Option = "None" + ) + v -> egg_smol::ast::Variant {name: (&v.name).into(), types: (&v.types).into_iter().map(|v| v.into()).collect(), cost: v.cost}, + v -> Variant {name: v.name.to_string(), types: v.types.iter().map(|v| v.to_string()).collect(), cost: v.cost}; + egg_smol::ast::Schema => Schema( + input: Vec, + output: String + ) + s -> egg_smol::ast::Schema {input: (&s.input).into_iter().map(|v| v.into()).collect(), output: (&s.output).into()}, + s -> Schema {input: s.input.iter().map(|v| v.to_string()).collect(), output: s.output.to_string()}; + egg_smol::ast::Rule[display] => Rule( + head: Vec, + body: Vec + ) + r -> egg_smol::ast::Rule {head: (&r.head).into_iter().map(|v| v.into()).collect(), body: (&r.body).into_iter().map(|v| v.into()).collect()}, + r -> Rule {head: r.head.iter().map(|v| v.into()).collect(), body: r.body.iter().map(|v| v.into()).collect()}; + egg_smol::ast::Rewrite => Rewrite( + lhs: Expr, + rhs: Expr, + conditions: Vec = "Vec::new()" + ) + r -> egg_smol::ast::Rewrite {lhs: (&r.lhs).into(), rhs: (&r.rhs).into(), conditions: (&r.conditions).into_iter().map(|v| v.into()).collect()}, + r -> Rewrite {lhs: (&r.lhs).into(), rhs: (&r.rhs).into(), conditions: r.conditions.iter().map(|v| v.into()).collect()} +); // Wrapped version of Duration // Converts from a rust duration to a python timedelta -pub struct WrappedDuration(Duration); +pub struct WrappedDuration(std::time::Duration); -impl From for WrappedDuration { - fn from(other: Duration) -> Self { +impl From for WrappedDuration { + fn from(other: std::time::Duration) -> Self { WrappedDuration(other) } } diff --git a/src/lib.rs b/src/lib.rs index 3cbb768c..260e5eb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod conversions; mod error; +mod utils; use conversions::*; use error::*; @@ -23,10 +24,6 @@ impl EGraph { } } - // fn __repr__(&mut self) -> PyResult { - // Ok(format!("{:#?}", self.egraph)) - // } - /// Extract the best expression of a given value. Will also return // variants number of additional options. // #[pyo3(text_signature = "($self, value, variants=0)")] @@ -43,15 +40,15 @@ impl EGraph { // )) // } - /// Check that a fact is true in the egraph. + // /// Check that a fact is true in the egraph. #[pyo3(text_signature = "($self, fact)")] - fn check_fact(&mut self, fact: WrappedFact) -> EggResult<()> { + fn check_fact(&mut self, fact: Fact_) -> EggResult<()> { self.egraph.check_fact(&fact.into())?; Ok({}) } - /// Run the rules on the egraph until it reaches a fixpoint, specifying the max number of iterations. - /// Returns a tuple of the total time spen searching, applying, and rebuilding. + // /// Run the rules on the egraph until it reaches a fixpoint, specifying the max number of iterations. + // /// Returns a tuple of the total time spen searching, applying, and rebuilding. #[pyo3(text_signature = "($self, limit)")] fn run_rules( &mut self, @@ -61,41 +58,41 @@ impl EGraph { Ok((search.into(), apply.into(), rebuild.into())) } - /// Define a rewrite rule, returning the name of the rule + // /// Define a rewrite rule, returning the name of the rule #[pyo3(text_signature = "($self, rewrite)")] - fn add_rewrite(&mut self, rewrite: WrappedRewrite) -> EggResult { + fn add_rewrite(&mut self, rewrite: Rewrite) -> EggResult { let res = self.egraph.add_rewrite(rewrite.into())?; Ok(res.to_string()) } - /// Define a new named value. + // /// Define a new named value. #[pyo3( text_signature = "($self, name, expr, cost=None)", signature = "(name, expr, cost=None)" )] - fn define(&mut self, name: String, expr: WrappedExpr, cost: Option) -> EggResult<()> { + fn define(&mut self, name: String, expr: Expr, cost: Option) -> EggResult<()> { self.egraph.define(name.into(), expr.into(), cost)?; Ok(()) } - /// Declare a new function definition. + // /// Declare a new function definition. #[pyo3(text_signature = "($self, decl)")] - fn declare_function(&mut self, decl: WrappedFunctionDecl) -> EggResult<()> { + fn declare_function(&mut self, decl: FunctionDecl) -> EggResult<()> { self.egraph.declare_function(&decl.into())?; Ok(()) } - /// Declare a new sort with the given name. + // /// Declare a new sort with the given name. #[pyo3(text_signature = "($self, name)")] fn declare_sort(&mut self, name: &str) -> EggResult<()> { self.egraph.declare_sort(name)?; Ok({}) } - /// Declare a new datatype constructor. + // /// Declare a new datatype constructor. #[pyo3(text_signature = "($self, variant, sort)")] fn declare_constructor(&mut self, variant: Variant, sort: &str) -> EggResult<()> { - self.egraph.declare_constructor(variant.0, sort)?; + self.egraph.declare_constructor(variant.into(), sort)?; Ok({}) } @@ -114,6 +111,9 @@ impl EGraph { fn bindings(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; + + add_structs_to_module(m)?; + add_enums_to_module(m)?; + Ok(()) } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 00000000..d9a18d37 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,221 @@ +use pyo3::prelude::*; + +pub fn display(t: &T) -> String +where + T: Clone, + V: std::fmt::Display + core::convert::From, +{ + format!("{:}", V::from(t.clone())) +} + +// Create a dataclass-like repr, of the name of the class of the object +// called with the repr of the fields +pub fn data_repr( + py: Python, + slf: PyRef, + field_names: Vec<&str>, +) -> PyResult { + let obj = slf.into_py(py); + let class_name: String = obj + .getattr(py, "__class__")? + .getattr(py, "__name__")? + .extract(py)?; + let field_strings: PyResult> = field_names + .iter() + .map(|name| { + obj.getattr(py, *name) + .and_then(|x| x.call_method(py, "__repr__", (), None)?.extract(py)) + }) + .collect(); + Ok(format!("{}({})", class_name, field_strings?.join(", "))) +} + +// Macro to create a wrapper around rust enums. +// We create Python classes for each variant of the enum +// and create a wrapper enum around all variants to enable conversion to/from Python +// and to/from egg_smol +#[macro_export] +macro_rules! convert_enums { + ($( + $from_type:ty => $to_type:ident { + $( + $variant:ident$([name=$py_name:literal])?($($field:ident: $field_type:ty),*) + $from_ident:ident -> $from:expr, + $to_pat:pat => $to:expr + );* + } + );*) => { + $($( + #[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)] + #[derive(Clone)] + pub struct $variant { + $( + #[pyo3(get)] + $field: $field_type, + )* + } + + #[pymethods] + impl $variant { + #[new] + fn new($($field: $field_type),*) -> Self { + Self { + $($field),* + } + } + + fn __repr__(slf: PyRef<'_, Self>, py: Python) -> PyResult { + data_repr(py, slf, vec![$(stringify!($field)),*]) + } + + fn __str__(&self) -> String { + display::<_, $from_type>(self) + } + } + + impl From<$variant> for $from_type { + fn from($from_ident: $variant) -> $from_type { + $from + } + } + impl From<&$variant> for $from_type { + fn from($from_ident: &$variant) -> $from_type { + $from + } + } + )* + + #[derive(FromPyObject, Clone)] + pub enum $to_type { + $( + $variant($variant), + )* + } + impl IntoPy for $to_type { + fn into_py(self, py: Python<'_>) -> PyObject { + match self { + $( + $to_type::$variant(v) => v.into_py(py), + )* + } + } + } + impl From<$to_type> for $from_type { + fn from(other: $to_type) -> Self { + match other { + $( + $to_type::$variant(v) => v.into(), + )* + } + } + } + + impl From<$from_type> for $to_type { + fn from(other: $from_type) -> Self { + match other { + $( + $to_pat => $to_type::$variant($to), + )* + } + } + } + + impl From<&$to_type> for $from_type { + fn from(other: &$to_type) -> Self { + match other { + $( + $to_type::$variant(v) => v.into(), + )* + } + } + } + + impl From<&$from_type> for $to_type { + fn from(other: &$from_type) -> Self { + match other { + $( + $to_pat => $to_type::$variant($to), + )* + } + } + } + )* + pub fn add_enums_to_module(module: &PyModule) -> PyResult<()> { + $( + $( + module.add_class::<$variant>()?; + )* + )* + Ok(()) + } + }; +} + +#[macro_export] +macro_rules! convert_struct { + ($( + $from_type:ty$([$str_fn:ident])? => $to_type:ident($($field:ident: $field_type:ty$( = $default:literal)?),*) + $from_ident:ident -> $from:expr, + $to_ident:ident -> $to:expr + );*) => { + $( + #[pyclass(frozen, module="egg_smol.bindings")] + #[derive(Clone)] + pub struct $to_type { + $( + #[pyo3(get)] + $field: $field_type, + )* + } + + #[pymethods] + impl $to_type { + #[new] + #[args($($($field = $default)?)*)] + fn new($($field: $field_type),*) -> Self { + Self { + $($field),* + } + } + + fn __repr__(slf: PyRef<'_, Self>, py: Python) -> PyResult { + data_repr(py, slf, vec![$(stringify!($field)),*]) + } + $( + fn __str__(&self) -> String { + $str_fn::<_, $from_type>(self) + } + )? + } + + impl From<&$to_type> for $from_type { + fn from($from_ident: &$to_type) -> $from_type { + $from + } + } + impl From<&$from_type> for $to_type { + fn from($to_ident: &$from_type) -> Self { + $to + } + } + impl From<$to_type> for $from_type { + fn from($from_ident: $to_type) -> $from_type { + $from + } + } + impl From<$from_type> for $to_type { + fn from($to_ident: $from_type) -> Self { + $to + } + } + )* + pub fn add_structs_to_module(module: &PyModule) -> PyResult<()> { + $( + module.add_class::<$to_type>()?; + )* + Ok(()) + } + }; +} +pub use convert_enums; +pub use convert_struct; diff --git a/stubtest_allow b/stubtest_allow new file mode 100644 index 00000000..275dd646 --- /dev/null +++ b/stubtest_allow @@ -0,0 +1 @@ +.*egg_smol.bindings.Unit.__init__.* From f8bae6468ff702e06028079aada267fd5cf60385 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 17 Nov 2022 18:55:14 -0500 Subject: [PATCH 6/7] Update cargo version --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 87f1096b..ed5f0a0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,7 +148,7 @@ dependencies = [ [[package]] name = "egg-smol" version = "0.1.0" -source = "git+https://github.com/saulshanabrook/egg-smol?branch=public-api#cefd8fdcc3b521e9f95bea16bdcfc57c089a16ea" +source = "git+https://github.com/mwillsey/egg-smol#9a45bdee9b5395ab821318a033c6ac7fa91f91b9" dependencies = [ "clap", "env_logger", diff --git a/Cargo.toml b/Cargo.toml index dc03b6ac..7840bcd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.17.1", features = ["extension-module"] } -egg-smol = { git = "https://github.com/saulshanabrook/egg-smol", branch = "public-api" } +egg-smol = { git = "https://github.com/mwillsey/egg-smol", ref = "9a45bdee9b5395ab821318a033c6ac7fa91f91b9" } [package.metadata.maturin] name = "egg_smol.bindings" From eac7f1dc6b3210f77ce542ca5dc549adf172e147 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 17 Nov 2022 19:00:35 -0500 Subject: [PATCH 7/7] Fix docs --- docs/explanation/compared_to_rust.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/explanation/compared_to_rust.md b/docs/explanation/compared_to_rust.md index 4598044c..4e372ec8 100644 --- a/docs/explanation/compared_to_rust.md +++ b/docs/explanation/compared_to_rust.md @@ -53,7 +53,6 @@ egg CLI works: ```{code-cell} python from egg_smol.bindings import * -from egg_smol.bindings_py import * eqsat_basic = """(datatype Math (Num i64)