Skip to content

Commit f5509d6

Browse files
Remove custom extract command and register cost model instead
1 parent fb6401e commit f5509d6

File tree

3 files changed

+19
-79
lines changed

3 files changed

+19
-79
lines changed

Cargo.lock

Lines changed: 9 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ default = ["bin"]
1717
bin = ["egglog/bin"]
1818

1919
[dependencies]
20-
egglog = { git = "https://github.com/egraphs-good/egglog.git", default-features = false, rev = "33f7994" }
21-
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", default-features = false, rev = "33f7994" }
22-
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", default-features = false, rev = "33f7994" }
23-
20+
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "extract-cost-model", default-features = false }
21+
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "extract-cost-model", default-features = false }
22+
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "extract-cost-model", default-features = false }
2423
num = "0.4.3"
2524
lazy_static = "1.4"
2625
log = "0.4"

src/set_cost.rs

Lines changed: 7 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@ use egglog_ast::span::Span;
22
use std::sync::Arc;
33

44
use egglog::{
5-
CommandOutput, EGraph, Error, Term, TermDag, TypeError, UserDefinedCommand,
5+
CostModelExtractorBuilder, EGraph,
66
ast::*,
7-
extract::{CostModel, DefaultCost, Extractor, TreeAdditiveCostModel},
7+
extract::{CostModel, DefaultCost, TreeAdditiveCostModel},
88
util::FreshGen,
99
};
10-
use log::log_enabled;
1110

1211
pub fn add_set_cost(egraph: &mut EGraph) {
1312
egraph
1413
.parser
1514
.add_command_macro(Arc::new(SetCostDeclarations));
1615
egraph.parser.add_action_macro(Arc::new(SetCost));
17-
egraph
18-
.add_command("extract".into(), Arc::new(CustomExtract))
19-
.unwrap();
16+
egraph.register_extractor(
17+
"dynamic-cost",
18+
Arc::new(CostModelExtractorBuilder::new(DynamicCostModel)),
19+
);
20+
egraph.set_default_extractor("dynamic-cost").unwrap();
2021
}
2122

2223
struct SetCost;
@@ -204,63 +205,3 @@ impl CostModel<DefaultCost> for DynamicCostModel {
204205
}
205206
}
206207
}
207-
208-
struct CustomExtract;
209-
210-
impl UserDefinedCommand for CustomExtract {
211-
fn update(
212-
&self,
213-
egraph: &mut EGraph,
214-
args: &[Expr],
215-
) -> Result<Option<CommandOutput>, egglog::Error> {
216-
assert!(args.len() <= 2);
217-
let (sort, value) = egraph.eval_expr(&args[0])?;
218-
let n = args.get(1).map(|arg| egraph.eval_expr(arg)).transpose()?;
219-
let n = if let Some(nv) = n {
220-
// TODO: egglog does not yet support u64
221-
if nv.0.name() != "i64" {
222-
let i64sort = egraph.get_arcsort_by(|s| s.name() == "i64");
223-
return Err(Error::TypeError(TypeError::Mismatch {
224-
expr: args[1].clone(),
225-
expected: i64sort,
226-
actual: nv.0,
227-
}));
228-
}
229-
egraph.value_to_base::<i64>(nv.1)
230-
} else {
231-
0
232-
};
233-
234-
let mut termdag = TermDag::default();
235-
236-
let extractor = Extractor::compute_costs_from_rootsorts(
237-
Some(vec![sort.clone()]),
238-
egraph,
239-
DynamicCostModel,
240-
);
241-
if n == 0 {
242-
if let Some((cost, term)) = extractor.extract_best(egraph, &mut termdag, value) {
243-
if log_enabled!(log::Level::Info) {
244-
log::info!("extracted with cost {cost}: {}", termdag.to_string(&term));
245-
}
246-
Ok(Some(CommandOutput::ExtractBest(termdag, cost, term)))
247-
} else {
248-
Err(Error::ExtractError(
249-
"Unable to find any valid extraction (likely due to subsume or delete)"
250-
.to_string(),
251-
))
252-
}
253-
} else {
254-
if n < 0 {
255-
panic!("Cannot extract negative number of variants");
256-
}
257-
let terms: Vec<Term> = extractor
258-
.extract_variants(egraph, &mut termdag, value, n as usize)
259-
.iter()
260-
.map(|e| e.1.clone())
261-
.collect();
262-
log::info!("extracted variants:");
263-
Ok(Some(CommandOutput::ExtractVariants(termdag, terms)))
264-
}
265-
}
266-
}

0 commit comments

Comments
 (0)