@@ -6,17 +6,11 @@ use crate::{conversions::Term, egraph::EGraph, egraph::Value, termdag::TermDag};
66
77#[ derive( Debug ) ]
88// We have to store the result, since the cost model does not return errors
9- struct Cost ( PyResult < Py < PyAny > > ) ;
9+ struct Cost ( Py < PyAny > ) ;
1010
1111impl Ord for Cost {
1212 fn cmp ( & self , other : & Self ) -> Ordering {
13- // Always order errors as smallest cost so they are prefered
14- match ( & self . 0 , & other. 0 ) {
15- ( Err ( _) , Err ( _) ) => Ordering :: Equal ,
16- ( Err ( _) , _) => Ordering :: Less ,
17- ( _, Err ( _) ) => Ordering :: Greater ,
18- ( Ok ( l) , Ok ( r) ) => Python :: attach ( |py| l. bind ( py) . compare ( r. bind ( py) ) . unwrap ( ) ) ,
19- }
13+ Python :: attach ( |py| self . 0 . bind ( py) . compare ( other. 0 . bind ( py) ) . unwrap ( ) )
2014 }
2115}
2216
@@ -28,26 +22,15 @@ impl PartialOrd for Cost {
2822
2923impl PartialEq for Cost {
3024 fn eq ( & self , other : & Self ) -> bool {
31- // errors are equal
32- match ( & self . 0 , & other. 0 ) {
33- ( Err ( _) , Err ( _) ) => true ,
34- ( Err ( _) , _) => false ,
35- ( _, Err ( _) ) => false ,
36- ( Ok ( l) , Ok ( r) ) => Python :: attach ( |py| l. bind ( py) . eq ( r. bind ( py) ) ) . unwrap ( ) ,
37- }
25+ Python :: attach ( |py| self . 0 . bind ( py) . eq ( other. 0 . bind ( py) ) ) . unwrap ( )
3826 }
3927}
4028
4129impl Eq for Cost { }
4230
4331impl Clone for Cost {
4432 fn clone ( & self ) -> Self {
45- Python :: attach ( |py| {
46- Cost ( match & self . 0 {
47- Ok ( v) => Ok ( v. clone_ref ( py) ) ,
48- Err ( e) => Err ( e. clone_ref ( py) ) ,
49- } )
50- } )
33+ Python :: attach ( |py| Cost ( self . 0 . clone_ref ( py) ) )
5134 }
5235}
5336
@@ -118,13 +101,15 @@ impl Clone for CostModel {
118101impl egglog:: extract:: CostModel < Cost > for CostModel {
119102 fn fold ( & self , head : & str , children_cost : & [ Cost ] , head_cost : Cost ) -> Cost {
120103 Cost ( Python :: attach ( |py| {
121- let head_cost = head_cost. 0 . map ( |v| v . clone_ref ( py) ) ? ;
104+ let head_cost = head_cost. 0 . clone_ref ( py) ;
122105 let children_cost = children_cost
123106 . into_iter ( )
124107 . cloned ( )
125- . map ( |c| c. 0 . map ( |v| v. clone_ref ( py) ) )
126- . collect :: < PyResult < Vec < _ > > > ( ) ?;
127- self . fold . call1 ( py, ( head, head_cost, children_cost) )
108+ . map ( |c| c. 0 . clone_ref ( py) )
109+ . collect :: < Vec < _ > > ( ) ;
110+ self . fold
111+ . call1 ( py, ( head, head_cost, children_cost) )
112+ . unwrap ( )
128113 } ) )
129114 }
130115
@@ -140,7 +125,7 @@ impl egglog::extract::CostModel<Cost> for CostModel {
140125 // this is not needed because the only thing we can do with the output is look up an analysis
141126 // which we can also do with the original function
142127 values. pop ( ) . unwrap ( ) ;
143- Cost ( self . enode_cost . call1 ( py, ( func. name ( ) , values) ) )
128+ Cost ( self . enode_cost . call1 ( py, ( func. name ( ) , values) ) . unwrap ( ) )
144129 } )
145130 }
146131
@@ -155,10 +140,11 @@ impl egglog::extract::CostModel<Cost> for CostModel {
155140 let element_costs = element_costs
156141 . into_iter ( )
157142 . cloned ( )
158- . map ( |c| c. 0 . map ( |v| v . clone_ref ( py) ) )
159- . collect :: < PyResult < Vec < _ > > > ( ) ? ;
143+ . map ( |c| c. 0 . clone_ref ( py) )
144+ . collect :: < Vec < _ > > ( ) ;
160145 self . container_cost
161146 . call1 ( py, ( sort. name ( ) , Value ( value) , element_costs) )
147+ . unwrap ( )
162148 } ) )
163149 }
164150
@@ -169,7 +155,13 @@ impl egglog::extract::CostModel<Cost> for CostModel {
169155 sort : & egglog:: ArcSort ,
170156 value : egglog:: Value ,
171157 ) -> Cost {
172- Python :: attach ( |py| Cost ( self . base_value_cost . call1 ( py, ( sort. name ( ) , Value ( value) ) ) ) )
158+ Python :: attach ( |py| {
159+ Cost (
160+ self . base_value_cost
161+ . call1 ( py, ( sort. name ( ) , Value ( value) ) )
162+ . unwrap ( ) ,
163+ )
164+ } )
173165 }
174166}
175167
@@ -233,7 +225,7 @@ impl Extractor {
233225 . 0
234226 . extract_best_with_sort ( & egraph. egraph , & mut termdag. 0 , value. 0 , sort. clone ( ) )
235227 . ok_or ( PyValueError :: new_err ( "Unextractable root" . to_string ( ) ) ) ?;
236- Ok ( ( cost. 0 ? . clone_ref ( py) , term. into ( ) ) )
228+ Ok ( ( cost. 0 . clone_ref ( py) , term. into ( ) ) )
237229 }
238230
239231 /// Extract variants of an e-class.
@@ -260,9 +252,9 @@ impl Extractor {
260252 nvariants,
261253 sort. clone ( ) ,
262254 ) ;
263- variants
255+ Ok ( variants
264256 . into_iter ( )
265- . map ( |( cost, term) | ( cost. 0 . map ( |c| ( c . clone_ref ( py) , term. into ( ) ) ) ) )
266- . collect ( )
257+ . map ( |( cost, term) | ( cost. 0 . clone_ref ( py) , term. into ( ) ) )
258+ . collect ( ) )
267259 }
268260}
0 commit comments