@@ -2,21 +2,22 @@ use egglog_ast::span::Span;
22use std:: sync:: Arc ;
33
44use egglog:: {
5- CommandOutput , EGraph , Error , Term , TermDag , TypeError , UserDefinedCommand ,
5+ CostModelExtractorBuilder , EGraph ,
66 ast:: * ,
7- extract:: { CostModel , DefaultCost , Extractor , TreeAdditiveCostModel } ,
7+ extract:: { CostModel , DefaultCost , TreeAdditiveCostModel } ,
88 util:: FreshGen ,
99} ;
10- use log:: log_enabled;
1110
1211pub fn add_set_cost ( egraph : & mut EGraph ) {
1312 egraph
1413 . parser
1514 . add_command_macro ( Arc :: new ( SetCostDeclarations ) ) ;
1615 egraph. parser . add_action_macro ( Arc :: new ( SetCost ) ) ;
17- egraph
18- . add_command ( "extract" . into ( ) , Arc :: new ( CustomExtract ) )
19- . unwrap ( ) ;
16+ egraph. register_extractor (
17+ "dynamic-cost" ,
18+ Arc :: new ( CostModelExtractorBuilder :: new ( DynamicCostModel ) ) ,
19+ ) ;
20+ egraph. set_default_extractor ( "dynamic-cost" ) . unwrap ( ) ;
2021}
2122
2223struct SetCost ;
@@ -204,63 +205,3 @@ impl CostModel<DefaultCost> for DynamicCostModel {
204205 }
205206 }
206207}
207-
208- struct CustomExtract ;
209-
210- impl UserDefinedCommand for CustomExtract {
211- fn update (
212- & self ,
213- egraph : & mut EGraph ,
214- args : & [ Expr ] ,
215- ) -> Result < Option < CommandOutput > , egglog:: Error > {
216- assert ! ( args. len( ) <= 2 ) ;
217- let ( sort, value) = egraph. eval_expr ( & args[ 0 ] ) ?;
218- let n = args. get ( 1 ) . map ( |arg| egraph. eval_expr ( arg) ) . transpose ( ) ?;
219- let n = if let Some ( nv) = n {
220- // TODO: egglog does not yet support u64
221- if nv. 0 . name ( ) != "i64" {
222- let i64sort = egraph. get_arcsort_by ( |s| s. name ( ) == "i64" ) ;
223- return Err ( Error :: TypeError ( TypeError :: Mismatch {
224- expr : args[ 1 ] . clone ( ) ,
225- expected : i64sort,
226- actual : nv. 0 ,
227- } ) ) ;
228- }
229- egraph. value_to_base :: < i64 > ( nv. 1 )
230- } else {
231- 0
232- } ;
233-
234- let mut termdag = TermDag :: default ( ) ;
235-
236- let extractor = Extractor :: compute_costs_from_rootsorts (
237- Some ( vec ! [ sort. clone( ) ] ) ,
238- egraph,
239- DynamicCostModel ,
240- ) ;
241- if n == 0 {
242- if let Some ( ( cost, term) ) = extractor. extract_best ( egraph, & mut termdag, value) {
243- if log_enabled ! ( log:: Level :: Info ) {
244- log:: info!( "extracted with cost {cost}: {}" , termdag. to_string( & term) ) ;
245- }
246- Ok ( Some ( CommandOutput :: ExtractBest ( termdag, cost, term) ) )
247- } else {
248- Err ( Error :: ExtractError (
249- "Unable to find any valid extraction (likely due to subsume or delete)"
250- . to_string ( ) ,
251- ) )
252- }
253- } else {
254- if n < 0 {
255- panic ! ( "Cannot extract negative number of variants" ) ;
256- }
257- let terms: Vec < Term > = extractor
258- . extract_variants ( egraph, & mut termdag, value, n as usize )
259- . iter ( )
260- . map ( |e| e. 1 . clone ( ) )
261- . collect ( ) ;
262- log:: info!( "extracted variants:" ) ;
263- Ok ( Some ( CommandOutput :: ExtractVariants ( termdag, terms) ) )
264- }
265- }
266- }
0 commit comments