Skip to content

Commit a0c37d6

Browse files
Add check_fact
1 parent 02acfae commit a0c37d6

File tree

3 files changed

+217
-4
lines changed

3 files changed

+217
-4
lines changed

python/egg_smol/bindings.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from typing import Optional
33

44
from typing_extensions import final
55

6-
from .bindings_py import Expr, FunctionDecl, Rewrite, Variant
6+
from .bindings_py import Expr, Fact_, FunctionDecl, Rewrite, Variant
77

88
@final
99
class EGraph:
@@ -14,6 +14,7 @@ class EGraph:
1414
def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ...
1515
def add_rewrite(self, rewrite: Rewrite) -> str: ...
1616
def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ...
17+
def check_fact(self, fact: Fact_) -> None: ...
1718

1819
@final
1920
class EggSmolError(Exception):

python/tests/test.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test_parse_and_run_program_exception(self):
2222
):
2323
egraph.parse_and_run_program(program)
2424

25+
# These examples are from eqsat-basic
2526
def test_datatype(self):
2627
egraph = EGraph()
2728
egraph.declare_sort("Math")
@@ -126,3 +127,210 @@ def test_run_rules(self):
126127
total_measured_time = searching + applying + rebuilding
127128
# Verify less than the total time (which includes time spent in Python).
128129
assert total_measured_time < (end_time - start_time)
130+
131+
def test_check_fact(self):
132+
egraph = EGraph()
133+
egraph.declare_sort("Math")
134+
egraph.declare_constructor(Variant("Num", ["i64"]), "Math")
135+
egraph.declare_constructor(Variant("Var", ["String"]), "Math")
136+
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
137+
egraph.declare_constructor(Variant("Mul", ["Math", "Math"]), "Math")
138+
139+
# (define expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))
140+
egraph.define(
141+
"expr1",
142+
Call(
143+
"Mul",
144+
[
145+
Call(
146+
"Num",
147+
[
148+
Lit(Int(2)),
149+
],
150+
),
151+
Call(
152+
"Add",
153+
[
154+
Call(
155+
"Var",
156+
[
157+
Lit(String("x")),
158+
],
159+
),
160+
Call(
161+
"Num",
162+
[
163+
Lit(Int(3)),
164+
],
165+
),
166+
],
167+
),
168+
],
169+
),
170+
)
171+
# (define expr2 (Add (Num 6) (Mul (Num 2) (Var "x"))))
172+
egraph.define(
173+
"expr2",
174+
Call(
175+
"Add",
176+
[
177+
Call(
178+
"Num",
179+
[
180+
Lit(Int(6)),
181+
],
182+
),
183+
Call(
184+
"Mul",
185+
[
186+
Call(
187+
"Num",
188+
[
189+
Lit(Int(2)),
190+
],
191+
),
192+
Call(
193+
"Var",
194+
[
195+
Lit(String("x")),
196+
],
197+
),
198+
],
199+
),
200+
],
201+
),
202+
)
203+
# (rewrite (Add a b)
204+
# (Add b a))
205+
egraph.add_rewrite(
206+
Rewrite(
207+
Call(
208+
"Add",
209+
[
210+
Var("a"),
211+
Var("b"),
212+
],
213+
),
214+
Call(
215+
"Add",
216+
[
217+
Var("b"),
218+
Var("a"),
219+
],
220+
),
221+
)
222+
)
223+
# (rewrite (Mul a (Add b c))
224+
# (Add (Mul a b) (Mul a c)))
225+
egraph.add_rewrite(
226+
Rewrite(
227+
Call(
228+
"Mul",
229+
[
230+
Var("a"),
231+
Call(
232+
"Add",
233+
[
234+
Var("b"),
235+
Var("c"),
236+
],
237+
),
238+
],
239+
),
240+
Call(
241+
"Add",
242+
[
243+
Call(
244+
"Mul",
245+
[
246+
Var("a"),
247+
Var("b"),
248+
],
249+
),
250+
Call(
251+
"Mul",
252+
[
253+
Var("a"),
254+
Var("c"),
255+
],
256+
),
257+
],
258+
),
259+
)
260+
)
261+
262+
# (rewrite (Add (Num a) (Num b))
263+
# (Num (+ a b)))
264+
lhs = Call(
265+
"Add",
266+
[
267+
Call(
268+
"Num",
269+
[
270+
Var("a"),
271+
],
272+
),
273+
Call(
274+
"Num",
275+
[
276+
Var("b"),
277+
],
278+
),
279+
],
280+
)
281+
rhs = Call(
282+
"Num",
283+
[
284+
Call(
285+
"+",
286+
[
287+
Var("a"),
288+
Var("b"),
289+
],
290+
)
291+
],
292+
)
293+
egraph.add_rewrite(Rewrite(lhs, rhs))
294+
295+
# (rewrite (Mul (Num a) (Num b))
296+
# (Num (* a b)))
297+
lhs = Call(
298+
"Mul",
299+
[
300+
Call(
301+
"Num",
302+
[
303+
Var("a"),
304+
],
305+
),
306+
Call(
307+
"Num",
308+
[
309+
Var("b"),
310+
],
311+
),
312+
],
313+
)
314+
rhs = Call(
315+
"Num",
316+
[
317+
Call(
318+
"*",
319+
[
320+
Var("a"),
321+
Var("b"),
322+
],
323+
)
324+
],
325+
)
326+
egraph.add_rewrite(Rewrite(lhs, rhs))
327+
328+
egraph.run_rules(10)
329+
egraph.check_fact(
330+
Eq(
331+
[
332+
Var("expr1"),
333+
Var("expr2"),
334+
]
335+
)
336+
)

src/lib.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
mod conversions;
22
mod error;
3-
use std::time::Duration;
43

54
use conversions::*;
65
use error::*;
@@ -24,6 +23,13 @@ impl EGraph {
2423
}
2524
}
2625

26+
/// Check that a fact is true in the egraph.
27+
#[pyo3(text_signature = "($self, fact)")]
28+
fn check_fact(&mut self, fact: WrappedFact) -> EggResult<()> {
29+
self.egraph.check_fact(&fact.into())?;
30+
Ok({})
31+
}
32+
2733
/// Run the rules on the egraph until it reaches a fixpoint, specifying the max number of iterations.
2834
/// Returns a tuple of the total time spen searching, applying, and rebuilding.
2935
#[pyo3(text_signature = "($self, limit)")]
@@ -32,8 +38,6 @@ impl EGraph {
3238
limit: usize,
3339
) -> EggResult<(WrappedDuration, WrappedDuration, WrappedDuration)> {
3440
let [search, apply, rebuild] = self.egraph.run_rules(limit);
35-
// Print all the timings
36-
println!("Timings: {:?}, {:?}, {:?}", search, apply, rebuild);
3741
Ok((search.into(), apply.into(), rebuild.into()))
3842
}
3943

0 commit comments

Comments
 (0)