Skip to content

Commit 38fd1f6

Browse files
Fix threading issue now that things are parallel
1 parent 6aa2d11 commit 38fd1f6

File tree

2 files changed

+128
-156
lines changed

2 files changed

+128
-156
lines changed

src/egraph.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,17 @@ impl EGraph {
6161
/// Returns a list of strings representing the output.
6262
/// An EggSmolError is raised if there is problem parsing or executing.
6363
#[pyo3(signature=(*commands))]
64-
fn run_program(&mut self, commands: Vec<Command>) -> EggResult<Vec<String>> {
64+
fn run_program(&mut self, py: Python<'_>, commands: Vec<Command>) -> EggResult<Vec<String>> {
6565
let commands: Vec<egglog::ast::Command> = commands.into_iter().map(|x| x.into()).collect();
6666
let mut cmds_str = String::new();
6767
for cmd in &commands {
6868
cmds_str = cmds_str + &cmd.to_string() + "\n";
6969
}
7070
info!("Running commands:\n{}", cmds_str);
71-
72-
let res = self.egraph.run_program(commands).map_err(|e| {
73-
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
71+
let res = py.allow_threads(|| {
72+
self.egraph.run_program(commands).map_err(|e| {
73+
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
74+
})
7475
});
7576
if res.is_ok() {
7677
if let Some(cmds) = &mut self.cmds {
@@ -115,22 +116,25 @@ impl EGraph {
115116
)]
116117
fn serialize(
117118
&mut self,
119+
py: Python<'_>,
118120
root_eclasses: Vec<Expr>,
119121
max_functions: Option<usize>,
120122
max_calls_per_function: Option<usize>,
121123
include_temporary_functions: bool,
122124
) -> SerializedEGraph {
123-
let root_eclasses: Vec<_> = root_eclasses
124-
.into_iter()
125-
.map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap())
126-
.collect();
127-
SerializedEGraph {
128-
egraph: self.egraph.serialize(SerializeConfig {
129-
max_functions,
130-
max_calls_per_function,
131-
include_temporary_functions,
132-
root_eclasses,
133-
}),
134-
}
125+
py.allow_threads(|| {
126+
let root_eclasses: Vec<_> = root_eclasses
127+
.into_iter()
128+
.map(|x| self.egraph.eval_expr(&egglog::ast::Expr::from(x)).unwrap())
129+
.collect();
130+
SerializedEGraph {
131+
egraph: self.egraph.serialize(SerializeConfig {
132+
max_functions,
133+
max_calls_per_function,
134+
include_temporary_functions,
135+
root_eclasses,
136+
}),
137+
}
138+
})
135139
}
136140
}

0 commit comments

Comments
 (0)