Skip to content

Commit 13e747b

Browse files
Add rewrite
1 parent 26bc7f9 commit 13e747b

File tree

5 files changed

+116
-2
lines changed

5 files changed

+116
-2
lines changed

python/egg_smol/bindings.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from typing import Optional
22

33
from typing_extensions import final
44

5-
from .bindings_py import Expr, FunctionDecl, Variant
5+
from .bindings_py import Expr, FunctionDecl, Rewrite, Variant
66

77
@final
88
class EGraph:
@@ -11,6 +11,7 @@ class EGraph:
1111
def declare_sort(self, name: str) -> None: ...
1212
def declare_function(self, decl: FunctionDecl) -> None: ...
1313
def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ...
14+
def add_rewrite(self, rewrite: Rewrite) -> str: ...
1415

1516
@final
1617
class EggSmolError(Exception):

python/egg_smol/bindings_py.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# TODO: Figure out what these modules should be called
22
from __future__ import annotations
33

4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Optional, Union
66

77

@@ -62,3 +62,23 @@ class Unit:
6262

6363

6464
Literal = Union[Int, String, Unit]
65+
66+
67+
@dataclass(frozen=True)
68+
class Rewrite:
69+
lhs: Expr
70+
rhs: Expr
71+
conditions: list[Fact_] = field(default_factory=list)
72+
73+
74+
@dataclass(frozen=True)
75+
class Fact:
76+
expr: Expr
77+
78+
79+
@dataclass
80+
class Eq:
81+
exprs: list[Expr]
82+
83+
84+
Fact_ = Union[Fact, Eq]

python/tests/test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,27 @@ def test_define(self):
6868
],
6969
),
7070
)
71+
72+
def test_rewrite(self):
73+
egraph = EGraph()
74+
egraph.declare_sort("Math")
75+
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
76+
name = egraph.add_rewrite(
77+
Rewrite(
78+
Call(
79+
"Add",
80+
[
81+
Var("a"),
82+
Var("b"),
83+
],
84+
),
85+
Call(
86+
"Add",
87+
[
88+
Var("b"),
89+
Var("a"),
90+
],
91+
),
92+
)
93+
)
94+
assert isinstance(name, str)

src/conversions.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//
44
// Converts from Python classes we define in pure python so we can use dataclasses
55
// to represent the input types
6+
// TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead.
67
use pyo3::prelude::*;
78

89
// Execute the block and wrap the error in a type error
@@ -175,3 +176,64 @@ impl From<WrappedLiteral> for egg_smol::ast::Literal {
175176
other.0
176177
}
177178
}
179+
180+
// Wrapped version of Rewrite
181+
pub struct WrappedRewrite(egg_smol::ast::Rewrite);
182+
183+
impl FromPyObject<'_> for WrappedRewrite {
184+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
185+
wrap_error("Rewrite", obj, || {
186+
Ok(WrappedRewrite(egg_smol::ast::Rewrite {
187+
lhs: obj.getattr("lhs")?.extract::<WrappedExpr>()?.into(),
188+
rhs: obj.getattr("rhs")?.extract::<WrappedExpr>()?.into(),
189+
conditions: obj
190+
.getattr("conditions")?
191+
.extract::<Vec<WrappedFact>>()?
192+
.into_iter()
193+
.map(|x| x.into())
194+
.collect(),
195+
}))
196+
})
197+
}
198+
}
199+
200+
impl From<WrappedRewrite> for egg_smol::ast::Rewrite {
201+
fn from(other: WrappedRewrite) -> Self {
202+
other.0
203+
}
204+
}
205+
206+
// Wrapped version of Fact
207+
pub struct WrappedFact(egg_smol::ast::Fact);
208+
209+
impl FromPyObject<'_> for WrappedFact {
210+
fn extract(obj: &'_ PyAny) -> PyResult<Self> {
211+
wrap_error("Fact", obj, || {
212+
extract_fact_eq(obj)
213+
.or_else(|_| extract_fact_fact(obj))
214+
.map(WrappedFact)
215+
})
216+
}
217+
}
218+
219+
fn extract_fact_eq(obj: &PyAny) -> PyResult<egg_smol::ast::Fact> {
220+
Ok(egg_smol::ast::Fact::Eq(
221+
obj.getattr("exprs")?
222+
.extract::<Vec<WrappedExpr>>()?
223+
.into_iter()
224+
.map(|x| x.into())
225+
.collect(),
226+
))
227+
}
228+
229+
fn extract_fact_fact(obj: &PyAny) -> PyResult<egg_smol::ast::Fact> {
230+
Ok(egg_smol::ast::Fact::Fact(
231+
obj.getattr("expr")?.extract::<WrappedExpr>()?.into(),
232+
))
233+
}
234+
235+
impl From<WrappedFact> for egg_smol::ast::Fact {
236+
fn from(other: WrappedFact) -> Self {
237+
other.0
238+
}
239+
}

src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ impl EGraph {
2222
}
2323
}
2424

25+
/// Define a rewrite rule, returning the name of the rule
26+
#[pyo3(text_signature = "($self, rewrite)")]
27+
fn add_rewrite(&mut self, rewrite: WrappedRewrite) -> EggResult<String> {
28+
let res = self.egraph.add_rewrite(rewrite.into())?;
29+
Ok(res.to_string())
30+
}
31+
2532
/// Define a new named value.
2633
#[pyo3(
2734
text_signature = "($self, name, expr, cost=None)",

0 commit comments

Comments
 (0)