Skip to content

Commit 9194ba8

Browse files
Update egglog dependency
Adds support for structured output added in egraphs-good/egglog#653. Adds a custom output for running schedules to return all reports and print them all.
1 parent 14755fd commit 9194ba8

File tree

6 files changed

+104
-59
lines changed

6 files changed

+104
-59
lines changed

Cargo.lock

Lines changed: 38 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ harness = false
88
name = "files"
99

1010
[dependencies]
11-
egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "5542549" }
11+
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "afe962958b57b6101f3f48e980f910fa063f3f6e" }
1212

1313
num = "0.4.3"
1414
lazy_static = "1.4"

src/scheduling.rs

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
use std::{collections::HashMap, sync::Mutex};
1+
use std::{
2+
collections::HashMap,
3+
sync::{Arc, Mutex},
4+
};
25

36
use egglog::{
47
ast::{Expr, Fact, Facts, Literal, ParseError},
58
prelude::{query, run_ruleset},
69
scheduler::{Scheduler, SchedulerId},
7-
RunReport, UserDefinedCommand,
10+
CommandOutput, RunReport, UserDefinedCommand,
811
};
912
use lazy_static::lazy_static;
1013

@@ -52,9 +55,12 @@ impl ScheduleState {
5255
};
5356

5457
if let Expr::Var(_, ruleset) = arg {
55-
run_ruleset(egraph, ruleset.as_str())?;
56-
57-
return Ok(egraph.get_run_report().clone().unwrap());
58+
let output = run_ruleset(egraph, ruleset.as_str())?;
59+
assert!(output.len() == 1);
60+
if let CommandOutput::RunSchedule(report) = &output[0] {
61+
return Ok(report.clone());
62+
}
63+
panic!("Expected a RunSchedule, got {:?}", output[0]);
5864
}
5965

6066
let Expr::Call(span, head, exprs) = arg else {
@@ -193,13 +199,35 @@ impl ScheduleState {
193199
}
194200
}
195201

202+
#[derive(Debug, Clone)]
203+
struct RunExtendedScheduleOutput {
204+
reports: Vec<RunReport>,
205+
}
206+
207+
impl std::fmt::Display for RunExtendedScheduleOutput {
208+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209+
writeln!(f, "Ran schedules:")?;
210+
for report in &self.reports {
211+
writeln!(f, "{}", report)?;
212+
}
213+
Ok(())
214+
}
215+
}
216+
196217
impl UserDefinedCommand for RunExtendedSchedule {
197-
fn update(&self, egraph: &mut egglog::EGraph, args: &[Expr]) -> Result<(), egglog::Error> {
218+
fn update(
219+
&self,
220+
egraph: &mut egglog::EGraph,
221+
args: &[Expr],
222+
) -> Result<Option<CommandOutput>, egglog::Error> {
198223
let mut schedule = ScheduleState::new();
224+
let mut reports = Vec::new();
199225
for arg in args {
200-
schedule.run(egraph, arg)?;
226+
reports.push(schedule.run(egraph, arg)?);
201227
}
202-
Ok(())
228+
Ok(Some(CommandOutput::UserDefined(Arc::new(
229+
RunExtendedScheduleOutput { reports },
230+
))))
203231
}
204232
}
205233

src/set_cost.rs

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ use egglog::{
44
ast::*,
55
extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel},
66
util::FreshGen,
7-
EGraph, Error, Term, TermDag, TypeError, UserDefinedCommand,
7+
CommandOutput, EGraph, Error, Term, TermDag, TypeError, UserDefinedCommand,
88
};
9+
use log::log_enabled;
910

1011
pub fn add_set_cost(egraph: &mut EGraph) {
1112
egraph
@@ -208,7 +209,11 @@ impl CostModel<DefaultCost> for DynamicCostModel {
208209
struct CustomExtract;
209210

210211
impl UserDefinedCommand for CustomExtract {
211-
fn update(&self, egraph: &mut EGraph, args: &[Expr]) -> Result<(), Error> {
212+
fn update(
213+
&self,
214+
egraph: &mut EGraph,
215+
args: &[Expr],
216+
) -> Result<Option<CommandOutput>, egglog::Error> {
212217
assert!(args.len() <= 2);
213218
let (sort, value) = egraph.eval_expr(&args[0])?;
214219
let n = args.get(1).map(|arg| egraph.eval_expr(arg)).transpose()?;
@@ -236,26 +241,15 @@ impl UserDefinedCommand for CustomExtract {
236241
);
237242
if n == 0 {
238243
if let Some((cost, term)) = extractor.extract_best(egraph, &mut termdag, value) {
239-
// dont turn termdag into a string if we have messages disabled for performance reasons
240-
if egraph.messages_enabled() {
241-
let extracted = termdag.to_string(&term);
242-
log::info!("extracted with cost {cost}: {extracted}");
243-
egraph.print_msg(extracted);
244+
if log_enabled!(log::Level::Info) {
245+
log::info!("extracted with cost {cost}: {}", termdag.to_string(&term));
244246
}
245-
// TODO: egraph.extract_report is private
246-
// A future implementation should make a egglog_experimental::EGraph
247-
// that provides a similar set of methods and overrides its own extract_report.
248-
//
249-
// egraph.extract_report = Some(ExtractReport::Best {
250-
// termdag,
251-
// cost,
252-
// term,
253-
// });
247+
Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
254248
} else {
255-
return Err(Error::ExtractError(
249+
Err(Error::ExtractError(
256250
"Unable to find any valid extraction (likely due to subsume or delete)"
257251
.to_string(),
258-
));
252+
))
259253
}
260254
} else {
261255
if n < 0 {
@@ -266,24 +260,8 @@ impl UserDefinedCommand for CustomExtract {
266260
.iter()
267261
.map(|e| e.1.clone())
268262
.collect();
269-
// Same as above, avoid turning termdag into a string if we have messages disabled for performance
270-
if egraph.messages_enabled() {
271-
log::info!("extracted variants:");
272-
let mut msg = String::default();
273-
msg += "(\n";
274-
assert!(!terms.is_empty());
275-
for expr in &terms {
276-
let str = termdag.to_string(expr);
277-
log::info!(" {str}");
278-
msg += &format!(" {str}\n");
279-
}
280-
msg += ")";
281-
egraph.print_msg(msg);
282-
}
283-
// TODO: Same as above. EGraph::extract_report is private.
284-
//
285-
// egraph.extract_report = Some(ExtractReport::Variants { termdag, terms });
263+
log::info!("extracted variants:");
264+
Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
286265
}
287-
Ok(())
288266
}
289267
}

tests/files.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ impl Run {
2222
);
2323
} else {
2424
let mut egraph = new_experimental_egraph();
25-
egraph.run_mode = RunMode::ShowDesugaredEgglog;
2625
let desugared_str = egraph
27-
.parse_and_run_program(self.path.to_str().map(String::from), &program)
26+
.resugar_program(self.path.to_str().map(String::from), &program)
2827
.unwrap()
2928
.join("\n");
3029

@@ -39,15 +38,19 @@ impl Run {
3938
fn test_program(&self, filename: Option<String>, program: &str, message: &str) {
4039
let mut egraph = new_experimental_egraph();
4140
match egraph.parse_and_run_program(filename, program) {
42-
Ok(msgs) => {
41+
Ok(outputs) => {
4342
if self.should_fail() {
4443
panic!(
4544
"Program should have failed! Instead, logged:\n {}",
46-
msgs.join("\n")
45+
outputs
46+
.iter()
47+
.map(|output| output.to_string())
48+
.collect::<Vec<_>>()
49+
.join("\n")
4750
);
4851
} else {
49-
for msg in msgs {
50-
println!(" {}", msg);
52+
for output in outputs {
53+
print!(" {}", output);
5154
}
5255
// Test graphviz dot generation
5356
let mut serialized = egraph.serialize(SerializeConfig {

tests/integration_test.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ fn test_extract() {
3535
.unwrap();
3636

3737
assert_eq!(result.len(), 5);
38-
assert_eq!(result[0], "(Add (Num 1) (Num 1))");
39-
assert_eq!(result[1], "(Num 2)");
40-
assert_eq!(result[2], "(Add (Num 1) (Num 1))");
41-
assert_eq!(result[3], "(Add (Num 1) (Num 1))");
42-
assert_eq!(result[4], "(Sub (Num 5) (Num 3))");
38+
assert_eq!(result[0].to_string(), "(Add (Num 1) (Num 1))\n");
39+
assert_eq!(result[1].to_string(), "(Num 2)\n");
40+
assert_eq!(result[2].to_string(), "(Add (Num 1) (Num 1))\n");
41+
assert_eq!(result[3].to_string(), "(Add (Num 1) (Num 1))\n");
42+
assert_eq!(result[4].to_string(), "(Sub (Num 5) (Num 3))\n");
4343
}
4444

4545
#[test]
@@ -74,7 +74,7 @@ fn test_extract_set_cost_decls() {
7474
"(with-dynamic-cost
7575
(datatype E (Add E E) (Sub E E :cost 200) (Num i64))
7676
(constructor Mul (E E) E :cost 100)
77-
(datatype*
77+
(datatype*
7878
(E2 (Add2 E2 E2) (Sub2 E2 E2 :cost 200) (List VecE2) (Num2 i64))
7979
(sort VecE2 (Vec E2))
8080
)

0 commit comments

Comments
 (0)