Skip to content

Commit 02acfae

Browse files
Add run command
1 parent 13e747b commit 02acfae

File tree

4 files changed

+81
-1
lines changed

4 files changed

+81
-1
lines changed

python/egg_smol/bindings.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
from typing import Optional
23

34
from typing_extensions import final
@@ -12,6 +13,7 @@ class EGraph:
1213
def declare_function(self, decl: FunctionDecl) -> None: ...
1314
def define(self, name: str, expr: Expr, cost: Optional[int] = None) -> None: ...
1415
def add_rewrite(self, rewrite: Rewrite) -> str: ...
16+
def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ...
1517

1618
@final
1719
class EggSmolError(Exception):

python/tests/test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import datetime
2+
13
import pytest
24
from egg_smol.bindings import EggSmolError, EGraph
35
from egg_smol.bindings_py import *
@@ -92,3 +94,35 @@ def test_rewrite(self):
9294
)
9395
)
9496
assert isinstance(name, str)
97+
98+
def test_run_rules(self):
99+
egraph = EGraph()
100+
egraph.declare_sort("Math")
101+
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
102+
egraph.add_rewrite(
103+
Rewrite(
104+
Call(
105+
"Add",
106+
[
107+
Var("a"),
108+
Var("b"),
109+
],
110+
),
111+
Call(
112+
"Add",
113+
[
114+
Var("b"),
115+
Var("a"),
116+
],
117+
),
118+
)
119+
)
120+
start_time = datetime.datetime.now()
121+
searching, applying, rebuilding = egraph.run_rules(10)
122+
end_time = datetime.datetime.now()
123+
assert isinstance(searching, datetime.timedelta)
124+
assert isinstance(applying, datetime.timedelta)
125+
assert isinstance(rebuilding, datetime.timedelta)
126+
total_measured_time = searching + applying + rebuilding
127+
# Verify less than the total time (which includes time spent in Python).
128+
assert total_measured_time < (end_time - start_time)

src/conversions.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
use std::time::Duration;
2+
13
// Create wrappers around input types so that convert from pyobjects to them
24
// and then from them to the egg_smol types
35
//
46
// Converts from Python classes we define in pure python so we can use dataclasses
57
// to represent the input types
68
// TODO: Copy strings of these from egg-smol... Maybe actually wrap those isntead.
7-
use pyo3::prelude::*;
9+
use pyo3::{ffi::PyDateTime_Delta, prelude::*, types::PyDelta};
810

911
// Execute the block and wrap the error in a type error
1012
fn wrap_error<T>(tp: &str, obj: &'_ PyAny, block: impl FnOnce() -> PyResult<T>) -> PyResult<T> {
@@ -237,3 +239,30 @@ impl From<WrappedFact> for egg_smol::ast::Fact {
237239
other.0
238240
}
239241
}
242+
243+
// Wrapped version of Duration
244+
// Converts from a rust duration to a python timedelta
245+
pub struct WrappedDuration(Duration);
246+
247+
impl From<Duration> for WrappedDuration {
248+
fn from(other: Duration) -> Self {
249+
WrappedDuration(other)
250+
}
251+
}
252+
253+
impl IntoPy<PyObject> for WrappedDuration {
254+
fn into_py(self, py: Python<'_>) -> PyObject {
255+
let d = self.0;
256+
PyDelta::new(
257+
py,
258+
0,
259+
0,
260+
d.as_millis()
261+
.try_into()
262+
.expect("Failed to convert miliseconds to int32 when converting duration"),
263+
true,
264+
)
265+
.expect("Failed to contruct timedelta")
266+
.into()
267+
}
268+
}

src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
mod conversions;
22
mod error;
3+
use std::time::Duration;
4+
35
use conversions::*;
46
use error::*;
57
use pyo3::prelude::*;
@@ -22,6 +24,19 @@ impl EGraph {
2224
}
2325
}
2426

27+
/// Run the rules on the egraph until it reaches a fixpoint, specifying the max number of iterations.
28+
/// Returns a tuple of the total time spen searching, applying, and rebuilding.
29+
#[pyo3(text_signature = "($self, limit)")]
30+
fn run_rules(
31+
&mut self,
32+
limit: usize,
33+
) -> EggResult<(WrappedDuration, WrappedDuration, WrappedDuration)> {
34+
let [search, apply, rebuild] = self.egraph.run_rules(limit);
35+
// Print all the timings
36+
println!("Timings: {:?}, {:?}, {:?}", search, apply, rebuild);
37+
Ok((search.into(), apply.into(), rebuild.into()))
38+
}
39+
2540
/// Define a rewrite rule, returning the name of the rule
2641
#[pyo3(text_signature = "($self, rewrite)")]
2742
fn add_rewrite(&mut self, rewrite: WrappedRewrite) -> EggResult<String> {

0 commit comments

Comments
 (0)