Skip to content

Commit 34b5561

Browse files
Don't emit so much duplicate info when an error is produced
1 parent 371d89a commit 34b5561

File tree

3 files changed

+15
-20
lines changed

3 files changed

+15
-20
lines changed

python/egglog/egraph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,17 +1019,17 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
10191019
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
10201020

10211021
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
1022-
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
1022+
egg_expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
10231023
# If we have defined any cost tables use the custom extraction
1024-
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
1024+
args = (egg_expr, bindings.Lit(span(2), bindings.Int(n)))
10251025
if self._state.cost_callables:
10261026
cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
10271027
else:
10281028
cmd = bindings.Extract(span(2), *args)
10291029
try:
10301030
return self._egraph.run_program(cmd)[0]
10311031
except BaseException as e:
1032-
raise add_note("Extracting: " + str(expr), e) # noqa: B904
1032+
raise add_note("while extracting expr:\n" + str(expr), e) # noqa: B904
10331033

10341034
def push(self) -> None:
10351035
"""

src/egraph.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,15 @@ impl EGraph {
7474
cmds_str = cmds_str + &cmd.to_string() + "\n";
7575
}
7676
info!("Running commands:\n{}", cmds_str);
77-
let res = py.detach(|| {
78-
self.egraph.run_program(commands).map_err(|e| {
79-
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
80-
})
81-
});
82-
if res.is_ok()
83-
&& let Some(cmds) = &mut self.cmds
84-
{
85-
cmds.push_str(&cmds_str);
77+
match py.detach(|| self.egraph.run_program(commands)) {
78+
Err(e) => Err(WrappedError::Egglog(e)),
79+
Ok(outputs) => {
80+
if let Some(cmds) = &mut self.cmds {
81+
cmds.push_str(&cmds_str);
82+
}
83+
Ok(outputs.into_iter().map(|o| o.into()).collect())
84+
}
8685
}
87-
res.map(|xs| xs.iter().map(|o| o.into()).collect())
8886
}
8987

9088
/// Returns the text of the commands that have been run so far, if `record` was passed.
@@ -139,7 +137,7 @@ impl EGraph {
139137
self.egraph
140138
.eval_expr(&expr)
141139
.map(|(s, v)| (s.name().to_string(), Value(v)))
142-
.map_err(|e| WrappedError::Egglog(e, format!("\nWhen evaluating expr: {expr}")))
140+
.map_err(|e| WrappedError::Egglog(e))
143141
}
144142

145143
fn value_to_i64(&self, v: Value) -> i64 {

src/error.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ impl EggSmolError {
2121
// https://pyo3.rs/latest/function/error_handling.html#foreign-rust-error-types
2222
// TODO: Create classes for each of these errors
2323
pub enum WrappedError {
24-
// Add additional context for egglog error
25-
Egglog(egglog::Error, String),
24+
Egglog(egglog::Error),
2625
ParseError(egglog::ast::ParseError),
2726
Py(PyErr),
2827
}
@@ -31,9 +30,7 @@ pub enum WrappedError {
3130
impl From<WrappedError> for PyErr {
3231
fn from(error: WrappedError) -> Self {
3332
match error {
34-
WrappedError::Egglog(error, str) => {
35-
PyErr::new::<EggSmolError, _>(error.to_string() + &str)
36-
}
33+
WrappedError::Egglog(error) => PyErr::new::<EggSmolError, _>(error.to_string()),
3734
WrappedError::Py(error) => error,
3835
WrappedError::ParseError(error) => PyErr::new::<EggSmolError, _>(error.to_string()),
3936
}
@@ -43,7 +40,7 @@ impl From<WrappedError> for PyErr {
4340
// Convert from an egglog::Error to a WrappedError
4441
impl From<egglog::Error> for WrappedError {
4542
fn from(other: egglog::Error) -> Self {
46-
Self::Egglog(other, String::new())
43+
Self::Egglog(other)
4744
}
4845
}
4946

0 commit comments

Comments
 (0)