Skip to content

Commit a705104

Browse files
Makes conversions more ergonomic
1 parent a16bf04 commit a705104

File tree

2 files changed

+33
-24
lines changed

2 files changed

+33
-24
lines changed

src/conversions.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Create wrappers around input types so that convert from pyobjects to them
2+
// and then from them to the egg_smol types
3+
//
4+
// Converts from Python classes we define in pure python so we can use dataclasses
5+
// to represent the input types
6+
use pyo3::prelude::*;
7+
8+
pub struct WrappedVariant(egg_smol::ast::Variant);
9+
10+
impl FromPyObject<'_> for WrappedVariant {
11+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
12+
Ok(WrappedVariant(egg_smol::ast::Variant {
13+
name: obj.getattr("name")?.extract::<String>()?.into(),
14+
cost: obj.getattr("cost")?.extract()?,
15+
types: obj
16+
.getattr("types")?
17+
.extract::<Vec<String>>()?
18+
.into_iter()
19+
.map(|x| x.into())
20+
.collect(),
21+
}))
22+
}
23+
}
24+
25+
impl From<WrappedVariant> for egg_smol::ast::Variant {
26+
fn from(other: WrappedVariant) -> Self {
27+
other.0
28+
}
29+
}

src/lib.rs

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
mod conversions;
12
mod error;
3+
use conversions::*;
24
use error::*;
35
use pyo3::prelude::*;
46

@@ -11,24 +13,6 @@ struct EGraph {
1113
egraph: egg_smol::EGraph,
1214
}
1315

14-
// Convert a Python Variant object into a rust variable, by getting the attributes
15-
fn get_variant(obj: &PyAny) -> PyResult<egg_smol::ast::Variant> {
16-
// TODO: Is there a way to do this more automatically?
17-
Ok(egg_smol::ast::Variant {
18-
name: obj
19-
.getattr(pyo3::intern!(obj.py(), "name"))?
20-
.extract::<String>()?
21-
.into(),
22-
cost: obj.getattr(pyo3::intern!(obj.py(), "cost"))?.extract()?,
23-
types: obj
24-
.getattr(pyo3::intern!(obj.py(), "types"))?
25-
.extract::<Vec<String>>()?
26-
.into_iter()
27-
.map(|x| x.into())
28-
.collect(),
29-
})
30-
}
31-
3216
#[pymethods]
3317
impl EGraph {
3418
#[new]
@@ -53,12 +37,8 @@ impl EGraph {
5337
/// --
5438
///
5539
/// Declare a new datatype constructor.
56-
fn declare_constructor(
57-
&mut self,
58-
#[pyo3(from_py_with = "get_variant")] variant: egg_smol::ast::Variant,
59-
sort: &str,
60-
) -> EggResult<()> {
61-
self.egraph.declare_constructor(variant, sort)?;
40+
fn declare_constructor(&mut self, variant: WrappedVariant, sort: &str) -> EggResult<()> {
41+
self.egraph.declare_constructor(variant.into(), sort)?;
6242
Ok({})
6343
}
6444

0 commit comments

Comments
 (0)