Skip to content

Commit fb0ffd0

Browse files
Remove error wrapping
1 parent 05934d9 commit fb0ffd0

File tree

1 file changed

+25
-33
lines changed

1 file changed

+25
-33
lines changed

src/extract.rs

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,11 @@ use crate::{conversions::Term, egraph::EGraph, egraph::Value, termdag::TermDag};
66

77
#[derive(Debug)]
88
// We have to store the result, since the cost model does not return errors
9-
struct Cost(PyResult<Py<PyAny>>);
9+
struct Cost(Py<PyAny>);
1010

1111
impl Ord for Cost {
1212
fn cmp(&self, other: &Self) -> Ordering {
13-
// Always order errors as smallest cost so they are prefered
14-
match (&self.0, &other.0) {
15-
(Err(_), Err(_)) => Ordering::Equal,
16-
(Err(_), _) => Ordering::Less,
17-
(_, Err(_)) => Ordering::Greater,
18-
(Ok(l), Ok(r)) => Python::attach(|py| l.bind(py).compare(r.bind(py)).unwrap()),
19-
}
13+
Python::attach(|py| self.0.bind(py).compare(other.0.bind(py)).unwrap())
2014
}
2115
}
2216

@@ -28,26 +22,15 @@ impl PartialOrd for Cost {
2822

2923
impl PartialEq for Cost {
3024
fn eq(&self, other: &Self) -> bool {
31-
// errors are equal
32-
match (&self.0, &other.0) {
33-
(Err(_), Err(_)) => true,
34-
(Err(_), _) => false,
35-
(_, Err(_)) => false,
36-
(Ok(l), Ok(r)) => Python::attach(|py| l.bind(py).eq(r.bind(py))).unwrap(),
37-
}
25+
Python::attach(|py| self.0.bind(py).eq(other.0.bind(py))).unwrap()
3826
}
3927
}
4028

4129
impl Eq for Cost {}
4230

4331
impl Clone for Cost {
4432
fn clone(&self) -> Self {
45-
Python::attach(|py| {
46-
Cost(match &self.0 {
47-
Ok(v) => Ok(v.clone_ref(py)),
48-
Err(e) => Err(e.clone_ref(py)),
49-
})
50-
})
33+
Python::attach(|py| Cost(self.0.clone_ref(py)))
5134
}
5235
}
5336

@@ -118,13 +101,15 @@ impl Clone for CostModel {
118101
impl egglog::extract::CostModel<Cost> for CostModel {
119102
fn fold(&self, head: &str, children_cost: &[Cost], head_cost: Cost) -> Cost {
120103
Cost(Python::attach(|py| {
121-
let head_cost = head_cost.0.map(|v| v.clone_ref(py))?;
104+
let head_cost = head_cost.0.clone_ref(py);
122105
let children_cost = children_cost
123106
.into_iter()
124107
.cloned()
125-
.map(|c| c.0.map(|v| v.clone_ref(py)))
126-
.collect::<PyResult<Vec<_>>>()?;
127-
self.fold.call1(py, (head, head_cost, children_cost))
108+
.map(|c| c.0.clone_ref(py))
109+
.collect::<Vec<_>>();
110+
self.fold
111+
.call1(py, (head, head_cost, children_cost))
112+
.unwrap()
128113
}))
129114
}
130115

@@ -140,7 +125,7 @@ impl egglog::extract::CostModel<Cost> for CostModel {
140125
// this is not needed because the only thing we can do with the output is look up an analysis
141126
// which we can also do with the original function
142127
values.pop().unwrap();
143-
Cost(self.enode_cost.call1(py, (func.name(), values)))
128+
Cost(self.enode_cost.call1(py, (func.name(), values)).unwrap())
144129
})
145130
}
146131

@@ -155,10 +140,11 @@ impl egglog::extract::CostModel<Cost> for CostModel {
155140
let element_costs = element_costs
156141
.into_iter()
157142
.cloned()
158-
.map(|c| c.0.map(|v| v.clone_ref(py)))
159-
.collect::<PyResult<Vec<_>>>()?;
143+
.map(|c| c.0.clone_ref(py))
144+
.collect::<Vec<_>>();
160145
self.container_cost
161146
.call1(py, (sort.name(), Value(value), element_costs))
147+
.unwrap()
162148
}))
163149
}
164150

@@ -169,7 +155,13 @@ impl egglog::extract::CostModel<Cost> for CostModel {
169155
sort: &egglog::ArcSort,
170156
value: egglog::Value,
171157
) -> Cost {
172-
Python::attach(|py| Cost(self.base_value_cost.call1(py, (sort.name(), Value(value)))))
158+
Python::attach(|py| {
159+
Cost(
160+
self.base_value_cost
161+
.call1(py, (sort.name(), Value(value)))
162+
.unwrap(),
163+
)
164+
})
173165
}
174166
}
175167

@@ -233,7 +225,7 @@ impl Extractor {
233225
.0
234226
.extract_best_with_sort(&egraph.egraph, &mut termdag.0, value.0, sort.clone())
235227
.ok_or(PyValueError::new_err("Unextractable root".to_string()))?;
236-
Ok((cost.0?.clone_ref(py), term.into()))
228+
Ok((cost.0.clone_ref(py), term.into()))
237229
}
238230

239231
/// Extract variants of an e-class.
@@ -260,9 +252,9 @@ impl Extractor {
260252
nvariants,
261253
sort.clone(),
262254
);
263-
variants
255+
Ok(variants
264256
.into_iter()
265-
.map(|(cost, term)| (cost.0.map(|c| (c.clone_ref(py), term.into()))))
266-
.collect()
257+
.map(|(cost, term)| (cost.0.clone_ref(py), term.into()))
258+
.collect())
267259
}
268260
}

0 commit comments

Comments
 (0)