Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,6 @@

# Exclude (POSIX) glob patterns for notebooks
# Temporarily exclude notebooks with unrelated errors (not @egraph.class_ issues)
nb_execution_excludepatterns = (
"explanation/2024_03_17_community_talk.ipynb", # sklearn config error
"explanation/indexing_pushdown.ipynb", # array_api_module NameError
)

# Execution timeout (seconds)
nb_execution_timeout = 60 * 10

Expand Down
9,186 changes: 830 additions & 8,356 deletions docs/explanation/2024_03_17_community_talk.ipynb

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions docs/explanation/indexing_pushdown.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@
"\n",
"from egglog.exp.array_api import *\n",
"\n",
"egraph = EGraph([array_api_module])\n",
"egraph = EGraph()\n",
"\n",
"\n",
"@egraph.register\n",
Expand All @@ -267,6 +267,7 @@
"\n",
"res = abs(NDArray.var(\"x\"))[NDArray.var(\"idx\")]\n",
"egraph.register(res)\n",
"egraph.run(array_api_schedule)\n",
"egraph.run(100)\n",
"egraph.display()\n",
"\n",
Expand Down Expand Up @@ -720,7 +721,7 @@
}
],
"source": [
"egraph = EGraph([array_api_module])\n",
"egraph = EGraph()\n",
"\n",
"\n",
"@function(cost=0)\n",
Expand Down Expand Up @@ -758,6 +759,7 @@
"\n",
"\n",
"egraph.register(res.shape, res.dtype, res.index(an_index()))\n",
"egraph.run(array_api_schedule)\n",
"egraph.run(100)\n",
"egraph.display()\n",
"\n",
Expand Down Expand Up @@ -807,4 +809,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
6 changes: 3 additions & 3 deletions docs/tutorials/getting-started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@
],
"source": [
"egraph.run(10)\n",
"egraph.extract(res)\n"
"egraph.extract(res)"
]
},
{
Expand Down Expand Up @@ -1098,7 +1098,7 @@
"# Create an example which should equal the kronecker product of A and B\n",
"ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
"egraph.run(20)\n",
"egraph.extract(ex1)\n"
"egraph.extract(ex1)"
]
},
{
Expand Down Expand Up @@ -1212,7 +1212,7 @@
"source": [
"ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n",
"egraph.run(20)\n",
"egraph.extract(ex2)\n"
"egraph.extract(ex2)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,17 +1019,17 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
return [cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs]

def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
egg_expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
# If we have defined any cost tables use the custom extraction
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
args = (egg_expr, bindings.Lit(span(2), bindings.Int(n)))
if self._state.cost_callables:
cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
else:
cmd = bindings.Extract(span(2), *args)
try:
return self._egraph.run_program(cmd)[0]
except BaseException as e:
raise add_note("Extracting: " + str(expr), e) # noqa: B904
raise add_note("while extracting expr:\n" + str(expr), e) # noqa: B904

def push(self) -> None:
"""
Expand Down
20 changes: 9 additions & 11 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,15 @@ impl EGraph {
cmds_str = cmds_str + &cmd.to_string() + "\n";
}
info!("Running commands:\n{}", cmds_str);
let res = py.detach(|| {
self.egraph.run_program(commands).map_err(|e| {
WrappedError::Egglog(e, "\nWhen running commands:\n".to_string() + &cmds_str)
})
});
if res.is_ok()
&& let Some(cmds) = &mut self.cmds
{
cmds.push_str(&cmds_str);
match py.detach(|| self.egraph.run_program(commands)) {
Err(e) => Err(WrappedError::Egglog(e)),
Ok(outputs) => {
if let Some(cmds) = &mut self.cmds {
cmds.push_str(&cmds_str);
}
Ok(outputs.into_iter().map(|o| o.into()).collect())
}
}
res.map(|xs| xs.iter().map(|o| o.into()).collect())
}

/// Returns the text of the commands that have been run so far, if `record` was passed.
Expand Down Expand Up @@ -139,7 +137,7 @@ impl EGraph {
self.egraph
.eval_expr(&expr)
.map(|(s, v)| (s.name().to_string(), Value(v)))
.map_err(|e| WrappedError::Egglog(e, format!("\nWhen evaluating expr: {expr}")))
.map_err(|e| WrappedError::Egglog(e))
}

fn value_to_i64(&self, v: Value) -> i64 {
Expand Down
9 changes: 3 additions & 6 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ impl EggSmolError {
// https://pyo3.rs/latest/function/error_handling.html#foreign-rust-error-types
// TODO: Create classes for each of these errors
pub enum WrappedError {
// Add additional context for egglog error
Egglog(egglog::Error, String),
Egglog(egglog::Error),
ParseError(egglog::ast::ParseError),
Py(PyErr),
}
Expand All @@ -31,9 +30,7 @@ pub enum WrappedError {
impl From<WrappedError> for PyErr {
fn from(error: WrappedError) -> Self {
match error {
WrappedError::Egglog(error, str) => {
PyErr::new::<EggSmolError, _>(error.to_string() + &str)
}
WrappedError::Egglog(error) => PyErr::new::<EggSmolError, _>(error.to_string()),
WrappedError::Py(error) => error,
WrappedError::ParseError(error) => PyErr::new::<EggSmolError, _>(error.to_string()),
}
Expand All @@ -43,7 +40,7 @@ impl From<WrappedError> for PyErr {
// Convert from an egglog::Error to a WrappedError
impl From<egglog::Error> for WrappedError {
fn from(other: egglog::Error) -> Self {
Self::Egglog(other, String::new())
Self::Egglog(other)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl egglog::extract::CostModel<Cost> for CostModel {

fn enode_cost(
&self,
egraph: &egglog::EGraph,
_egraph: &egglog::EGraph,
func: &egglog::Function,
row: &egglog::FunctionRow<'_>,
) -> Cost {
Expand Down