Skip to content

Commit 1a84910

Browse files
move variant into rust
1 parent 71093ea commit 1a84910

File tree

5 files changed

+104
-26
lines changed

5 files changed

+104
-26
lines changed

python/egg_smol/bindings.pyi

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@ from typing import Optional
33

44
from typing_extensions import final
55

6-
from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite, Variant
6+
from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite
7+
8+
@final
9+
class Variant:
10+
def __init__(
11+
self, name: str, types: list[str], cost: Optional[int] = None
12+
) -> None: ...
13+
name: str
14+
types: list[str]
15+
cost: Optional[int]
716

817
@final
918
class EGraph:

python/egg_smol/bindings_py.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@
55
from typing import Optional, Union
66

77

8-
@dataclass(frozen=True)
9-
class Variant:
10-
name: str
11-
types: list[str]
12-
cost: Optional[int] = None
13-
14-
158
@dataclass(frozen=True)
169
class FunctionDecl:
1710
name: str

python/tests/test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22

33
import pytest
4-
from egg_smol.bindings import EggSmolError, EGraph
4+
from egg_smol.bindings import *
55
from egg_smol.bindings_py import *
66

77

@@ -334,3 +334,15 @@ def test_check_fact(self):
334334
]
335335
)
336336
)
337+
338+
# def test_extract(self):
339+
# # Example from extraction-cost
340+
# egraph = EGraph()
341+
# egraph.declare_sort("Expr")
342+
# egraph.declare_constructor(Variant("Num", ["i64"], cost=5), "Expr")
343+
344+
# egraph.define("x", Call("Num", [Lit(Int(1))]), cost=10)
345+
# egraph.define("y", Call("Num", [Lit(Int(2))]), cost=1)
346+
347+
# assert egraph.extract("x") == Call("Num", [Lit(Int(1))])
348+
# assert egraph.extract("y") == Var("y")

src/conversions.rs

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::time::Duration;
66
// Converts from Python classes we define in pure python so we can use dataclasses
77
// to represent the input types
88
// TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead.
9-
use pyo3::{ffi::PyDateTime_Delta, prelude::*, types::PyDelta};
9+
use pyo3::prelude::*;
1010

1111
// Execute the block and wrap the error in a type error
1212
fn wrap_error<T>(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult<T>) -> PyResult<T> {
@@ -18,24 +18,55 @@ fn wrap_error<T>(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult<T>)
1818
})
1919
}
2020

21-
// Wrapped version of Variant
22-
pub struct WrappedVariant(egg_smol::ast::Variant);
21+
// A variant of a constructor
2322

24-
impl FromPyObject<'_> for WrappedVariant {
25-
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
26-
wrap_error("Variant", obj, || {
27-
Ok(WrappedVariant(egg_smol::ast::Variant {
28-
name: obj.getattr("name")?.extract::<String>()?.into(),
29-
cost: obj.getattr("cost")?.extract()?,
30-
types: obj
31-
.getattr("types")?
32-
.extract::<Vec<String>>()?
33-
.into_iter()
34-
.map(|x| x.into())
35-
.collect(),
36-
}))
23+
#[pyclass(name = "Variant")]
24+
#[derive(Clone)]
25+
pub(crate) struct WrappedVariant(egg_smol::ast::Variant);
26+
27+
#[pymethods]
28+
impl WrappedVariant {
29+
#[new]
30+
fn new(name: String, types: Vec<String>, cost: Option<usize>) -> Self {
31+
Self(egg_smol::ast::Variant {
32+
name: name.into(),
33+
types: types.into_iter().map(|x| x.into()).collect(),
34+
cost,
3735
})
3836
}
37+
#[getter]
38+
fn name(&self) -> &str {
39+
self.0.name.into()
40+
}
41+
#[getter]
42+
fn types(&self) -> Vec<String> {
43+
self.0.types.iter().map(|x| x.to_string()).collect()
44+
}
45+
#[getter]
46+
fn cost(&self) -> Option<usize> {
47+
self.0.cost
48+
}
49+
50+
fn __repr__(&self) -> String {
51+
format!(
52+
"Variant(name={}, types=[{}], cost={})",
53+
self.0.name.to_string(),
54+
self.0
55+
.types
56+
.iter()
57+
.map(|x| x.to_string())
58+
.collect::<Vec<_>>()
59+
.join(", "),
60+
match self.0.cost {
61+
Some(x) => x.to_string(),
62+
None => "None".to_string(),
63+
}
64+
)
65+
}
66+
67+
fn __str__(&self) -> String {
68+
format!("{:#?}", self.0)
69+
}
3970
}
4071

4172
impl From<WrappedVariant> for egg_smol::ast::Variant {
@@ -44,6 +75,12 @@ impl From<WrappedVariant> for egg_smol::ast::Variant {
4475
}
4576
}
4677

78+
impl From<egg_smol::ast::Variant> for WrappedVariant {
79+
fn from(other: egg_smol::ast::Variant) -> Self {
80+
WrappedVariant(other)
81+
}
82+
}
83+
4784
// Wrapped version of FunctionDecl
4885
pub struct WrappedFunctionDecl(egg_smol::ast::FunctionDecl);
4986
impl FromPyObject<'_> for WrappedFunctionDecl {
@@ -140,6 +177,12 @@ impl From<WrappedExpr> for egg_smol::ast::Expr {
140177
}
141178
}
142179

180+
impl From<egg_smol::ast::Expr> for WrappedExpr {
181+
fn from(other: egg_smol::ast::Expr) -> Self {
182+
WrappedExpr(other)
183+
}
184+
}
185+
143186
// Wrapped version of Literal
144187
pub struct WrappedLiteral(egg_smol::ast::Literal);
145188

@@ -253,7 +296,7 @@ impl From<Duration> for WrappedDuration {
253296
impl IntoPy<PyObject> for WrappedDuration {
254297
fn into_py(self, py: Python<'_>) -> PyObject {
255298
let d = self.0;
256-
PyDelta::new(
299+
pyo3::types::PyDelta::new(
257300
py,
258301
0,
259302
0,

src/lib.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,26 @@ impl EGraph {
2323
}
2424
}
2525

26+
// fn __repr__(&mut self) -> PyResult<String> {
27+
// Ok(format!("{:#?}", self.egraph))
28+
// }
29+
30+
/// Extract the best expression of a given value. Will also return
31+
// variants number of additional options.
32+
// #[pyo3(text_signature = "($self, value, variants=0)")]
33+
// fn extract_expr(
34+
// &mut self,
35+
// value: WrappedExpr,
36+
// variants: usize,
37+
// ) -> EggResult<(usize, WrappedExpr, Vec<WrappedExpr>)> {
38+
// let (cost, expr, exprs) = self.egraph.extract_expr(value.into(), variants)?;
39+
// Ok((
40+
// cost,
41+
// expr.into(),
42+
// exprs.into_iter().map(|x| x.into()).collect(),
43+
// ))
44+
// }
45+
2646
/// Check that a fact is true in the egraph.
2747
#[pyo3(text_signature = "($self, fact)")]
2848
fn check_fact(&mut self, fact: WrappedFact) -> EggResult<()> {
@@ -94,5 +114,6 @@ impl EGraph {
94114
fn bindings(_py: Python, m: &PyModule) -> PyResult<()> {
95115
m.add_class::<EGraph>()?;
96116
m.add_class::<EggSmolError>()?;
117+
m.add_class::<WrappedVariant>()?;
97118
Ok(())
98119
}

0 commit comments

Comments
 (0)