Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,20 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
egraph.register(expr)
egraph.run(array_api_schedule)
try:
return egraph.eval(prim_expr)
extracted = egraph.extract(prim_expr)
except EggSmolError as exc:
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
# Try giving some context, by showing the smallest version of the larger expression
try:
msg = f"Cannot simplify to primitive {egraph.extract(expr)}"
except EggSmolError:
msg = f"Cannot simplify to primitive or extract {expr}"

# string = (
# egraph.as_egglog_string
# + "\n"
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
# )
# # save to "tmp.egg"
# with open("tmp.egg", "w") as f:
# f.write(string)
expr_extracted = egraph.extract(expr)
except EggSmolError as inner_exc:
raise ValueError(f"Cannot simplify {expr}") from inner_exc
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
msg = f"Cannot simplify to primitive {expr_extracted}"
raise ValueError(msg) from exc
return egraph.eval(extracted)

# string = (
# egraph.as_egglog_string
# + "\n"
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
# )
# # save to "tmp.egg"
# with open("tmp.egg", "w") as f:
# f.write(string)
11 changes: 4 additions & 7 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
from typing import Any

import numba
import pytest
Expand Down Expand Up @@ -79,7 +79,7 @@ def run_lda(x, y):

iris = datasets.load_iris()
X_np, y_np = (iris.data, iris.target)
res = run_lda(X_np, y_np)
res_np = run_lda(X_np, y_np)


def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
Expand Down Expand Up @@ -107,10 +107,7 @@ def load_source(expr, egraph: EGraph):
with egraph:
fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y")))
egraph.run(array_api_program_gen_schedule)
# cast b/c issue with it not recognizing py_object as property
cast(Any, egraph.eval(fn_program.py_object))
assert np.allclose(res, run_lda(X_np, y_np))
return egraph.eval(fn_program.statements)
return egraph.eval(egraph.extract(fn_program.statements))


def trace_lda(egraph: EGraph):
Expand Down Expand Up @@ -163,5 +160,5 @@ def test_source_optimized(self, snapshot_py, benchmark):
)
def test_execution(self, fn, benchmark):
# warmup once for numba
assert np.allclose(res, fn(X_np, y_np))
assert np.allclose(res_np, fn(X_np, y_np))
benchmark(fn, X_np, y_np)
Loading