Skip to content

Commit a7a26d1

Browse files
Merge pull request #218 from egraphs-good/update-benchmarks
Update benchmarks to make it easier to generate a complete egglog string
2 parents 12d8d5f + eac2634 commit a7a26d1

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

python/egglog/exp/array_api.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,20 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
15381538
egraph.register(expr)
15391539
egraph.run(array_api_schedule)
15401540
try:
1541-
return egraph.eval(prim_expr)
1541+
extracted = egraph.extract(prim_expr)
15421542
except EggSmolError as exc:
1543-
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1543+
# Try giving some context, by showing the smallest version of the larger expression
15441544
try:
1545-
msg = f"Cannot simplify to primitive {egraph.extract(expr)}"
1546-
except EggSmolError:
1547-
msg = f"Cannot simplify to primitive or extract {expr}"
1548-
1549-
# string = (
1550-
# egraph.as_egglog_string
1551-
# + "\n"
1552-
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
1553-
# )
1554-
# # save to "tmp.egg"
1555-
# with open("tmp.egg", "w") as f:
1556-
# f.write(string)
1545+
expr_extracted = egraph.extract(expr)
1546+
except EggSmolError as inner_exc:
1547+
raise ValueError(f"Cannot simplify {expr}") from inner_exc
1548+
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1549+
msg = f"Cannot simplify to primitive {expr_extracted}"
15571550
raise ValueError(msg) from exc
1551+
return egraph.eval(extracted)
1552+
1553+
# string = (
1554+
# egraph.as_egglog_string
1555+
# + "\n"
1556+
# + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__))
1557+
# )
1558+
# # save to "tmp.egg"
1559+
# with open("tmp.egg", "w") as f:
1560+
# f.write(string)

python/tests/test_array_api.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22
from collections.abc import Callable
33
from pathlib import Path
4-
from typing import Any, cast
4+
from typing import Any
55

66
import numba
77
import pytest
@@ -79,7 +79,7 @@ def run_lda(x, y):
7979

8080
iris = datasets.load_iris()
8181
X_np, y_np = (iris.data, iris.target)
82-
res = run_lda(X_np, y_np)
82+
res_np = run_lda(X_np, y_np)
8383

8484

8585
def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
@@ -107,10 +107,7 @@ def load_source(expr, egraph: EGraph):
107107
with egraph:
108108
fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y")))
109109
egraph.run(array_api_program_gen_schedule)
110-
# cast b/c issue with it not recognizing py_object as property
111-
cast(Any, egraph.eval(fn_program.py_object))
112-
assert np.allclose(res, run_lda(X_np, y_np))
113-
return egraph.eval(fn_program.statements)
110+
return egraph.eval(egraph.extract(fn_program.statements))
114111

115112

116113
def trace_lda(egraph: EGraph):
@@ -163,5 +160,5 @@ def test_source_optimized(self, snapshot_py, benchmark):
163160
)
164161
def test_execution(self, fn, benchmark):
165162
# warmup once for numba
166-
assert np.allclose(res, fn(X_np, y_np))
163+
assert np.allclose(res_np, fn(X_np, y_np))
167164
benchmark(fn, X_np, y_np)

0 commit comments

Comments
 (0)