Skip to content

Commit b33d55e

Browse files
Merge pull request #372 from egraphs-good/codex/github-mention-make-docs-builds-fail-on-notebook-execution
Fix docs execution for community talk and indexing notebooks
2 parents 4c1344d + f5b5ff2 commit b33d55e

File tree

8 files changed

+854
-8388
lines changed

8 files changed

+854
-8388
lines changed

docs/conf.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,6 @@
166166

167167
# Exclude (POSIX) glob patterns for notebooks
168168
# Temporarily exclude notebooks with unrelated errors (not @egraph.class_ issues)
169-
nb_execution_excludepatterns = (
170-
"explanation/2024_03_17_community_talk.ipynb", # sklearn config error
171-
"explanation/indexing_pushdown.ipynb", # array_api_module NameError
172-
)
173-
174169
# Execution timeout (seconds)
175170
nb_execution_timeout = 60 * 10
176171

docs/explanation/2024_03_17_community_talk.ipynb

Lines changed: 830 additions & 8356 deletions
Large diffs are not rendered by default.

docs/explanation/indexing_pushdown.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@
257257
"\n",
258258
"from egglog.exp.array_api import *\n",
259259
"\n",
260-
"egraph = EGraph([array_api_module])\n",
260+
"egraph = EGraph()\n",
261261
"\n",
262262
"\n",
263263
"@egraph.register\n",
@@ -267,6 +267,7 @@
267267
"\n",
268268
"res = abs(NDArray.var(\"x\"))[NDArray.var(\"idx\")]\n",
269269
"egraph.register(res)\n",
270+
"egraph.run(array_api_schedule)\n",
270271
"egraph.run(100)\n",
271272
"egraph.display()\n",
272273
"\n",
@@ -720,7 +721,7 @@
720721
}
721722
],
722723
"source": [
723-
"egraph = EGraph([array_api_module])\n",
724+
"egraph = EGraph()\n",
724725
"\n",
725726
"\n",
726727
"@function(cost=0)\n",
@@ -758,6 +759,7 @@
758759
"\n",
759760
"\n",
760761
"egraph.register(res.shape, res.dtype, res.index(an_index()))\n",
762+
"egraph.run(array_api_schedule)\n",
761763
"egraph.run(100)\n",
762764
"egraph.display()\n",
763765
"\n",
@@ -807,4 +809,4 @@
807809
},
808810
"nbformat": 4,
809811
"nbformat_minor": 2
810-
}
812+
}

docs/tutorials/getting-started.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@
409409
],
410410
"source": [
411411
"egraph.run(10)\n",
412-
"egraph.extract(res)\n"
412+
"egraph.extract(res)"
413413
]
414414
},
415415
{
@@ -1098,7 +1098,7 @@
10981098
"# Create an example which should equal the kronecker product of A and B\n",
10991099
"ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
11001100
"egraph.run(20)\n",
1101-
"egraph.extract(ex1)\n"
1101+
"egraph.extract(ex1)"
11021102
]
11031103
},
11041104
{
@@ -1212,7 +1212,7 @@
12121212
"source": [
12131213
"ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n",
12141214
"egraph.run(20)\n",
1215-
"egraph.extract(ex2)\n"
1215+
"egraph.extract(ex2)"
12161216
]
12171217
},
12181218
{

python/egglog/egraph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,17 +1019,17 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
10191019
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]
10201020

10211021
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
1022-
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
1022+
egg_expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
10231023
# If we have defined any cost tables use the custom extraction
1024-
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
1024+
args = (egg_expr, bindings.Lit(span(2), bindings.Int(n)))
10251025
if self._state.cost_callables:
10261026
cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
10271027
else:
10281028
cmd = bindings.Extract(span(2), *args)
10291029
try:
10301030
return self._egraph.run_program(cmd)[0]
10311031
except BaseException as e:
1032-
raise add_note("Extracting: " + str(expr), e) # noqa: B904
1032+
raise add_note("while extracting expr:\n" + str(expr), e) # noqa: B904
10331033

10341034
def push(self) -> None:
10351035
"""

src/egraph.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,15 @@ impl EGraph {
7474
cmds_str = cmds_str + &cmd.to_string() + "\n";
7575
}
7676
info!("Running commands:\n{}", cmds_str);
77-
let res = py.detach(|| {
78-
self.egraph.run_program(commands).map_err(|e| {
79-
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
80-
})
81-
});
82-
if res.is_ok()
83-
&& let Some(cmds) = &mut self.cmds
84-
{
85-
cmds.push_str(&cmds_str);
77+
match py.detach(|| self.egraph.run_program(commands)) {
78+
Err(e) => Err(WrappedError::Egglog(e)),
79+
Ok(outputs) => {
80+
if let Some(cmds) = &mut self.cmds {
81+
cmds.push_str(&cmds_str);
82+
}
83+
Ok(outputs.into_iter().map(|o| o.into()).collect())
84+
}
8685
}
87-
res.map(|xs| xs.iter().map(|o| o.into()).collect())
8886
}
8987

9088
/// Returns the text of the commands that have been run so far, if `record` was passed.
@@ -139,7 +137,7 @@ impl EGraph {
139137
self.egraph
140138
.eval_expr(&expr)
141139
.map(|(s, v)| (s.name().to_string(), Value(v)))
142-
.map_err(|e| WrappedError::Egglog(e, format!("\nWhen evaluating expr: {expr}")))
140+
.map_err(|e| WrappedError::Egglog(e))
143141
}
144142

145143
fn value_to_i64(&self, v: Value) -> i64 {

src/error.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ impl EggSmolError {
2121
// https://pyo3.rs/latest/function/error_handling.html#foreign-rust-error-types
2222
// TODO: Create classes for each of these errors
2323
pub enum WrappedError {
24-
// Add additional context for egglog error
25-
Egglog(egglog::Error, String),
24+
Egglog(egglog::Error),
2625
ParseError(egglog::ast::ParseError),
2726
Py(PyErr),
2827
}
@@ -31,9 +30,7 @@ pub enum WrappedError {
3130
impl From<WrappedError> for PyErr {
3231
fn from(error: WrappedError) -> Self {
3332
match error {
34-
WrappedError::Egglog(error, str) => {
35-
PyErr::new::<EggSmolError, _>(error.to_string() + &str)
36-
}
33+
WrappedError::Egglog(error) => PyErr::new::<EggSmolError, _>(error.to_string()),
3734
WrappedError::Py(error) => error,
3835
WrappedError::ParseError(error) => PyErr::new::<EggSmolError, _>(error.to_string()),
3936
}
@@ -43,7 +40,7 @@ impl From<WrappedError> for PyErr {
4340
// Convert from an egglog::Error to a WrappedError
4441
impl From<egglog::Error> for WrappedError {
4542
fn from(other: egglog::Error) -> Self {
46-
Self::Egglog(other, String::new())
43+
Self::Egglog(other)
4744
}
4845
}
4946

src/extract.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ impl egglog::extract::CostModel<Cost> for CostModel {
115115

116116
fn enode_cost(
117117
&self,
118-
egraph: &egglog::EGraph,
118+
_egraph: &egglog::EGraph,
119119
func: &egglog::Function,
120120
row: &egglog::FunctionRow<'_>,
121121
) -> Cost {

0 commit comments

Comments
 (0)