From 406cfbdd11d53bef4af51fc89ca987a6f22a7f38 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 23 Oct 2024 10:31:27 -0400 Subject: [PATCH 1/5] Update benchmarks to make it easier to generate a complete egglog string --- python/egglog/exp/array_api.py | 9 +++++---- python/tests/test_array_api.py | 23 ++++++++++------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index f6b062a8..35754af7 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1537,14 +1537,15 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool: egraph = EGraph.current() egraph.register(expr) egraph.run(array_api_schedule) + try: + extracted = egraph.extract(expr) + except EggSmolError as exc: + raise ValueError(f"Cannot extract {expr}") from exc try: return egraph.eval(prim_expr) except EggSmolError as exc: egraph.display(n_inline_leaves=1, split_primitive_outputs=True) - try: - msg = f"Cannot simplify to primitive {egraph.extract(expr)}" - except EggSmolError: - msg = f"Cannot simplify to primitive or extract {expr}" + msg = f"Cannot simplify to primitive {extracted}" # string = ( # egraph.as_egglog_string diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 497e5c1a..8300b698 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,7 +1,7 @@ import ast from collections.abc import Callable from pathlib import Path -from typing import Any, cast +from typing import Any import numba import pytest @@ -79,7 +79,7 @@ def run_lda(x, y): iris = datasets.load_iris() X_np, y_np = (iris.data, iris.target) -res = run_lda(X_np, y_np) +res_np = run_lda(X_np, y_np) def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: @@ -104,13 +104,9 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: def load_source(expr, egraph: EGraph): - with egraph: - fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) - egraph.run(array_api_program_gen_schedule) - # cast b/c issue with it not recognizing py_object as property - cast(Any, egraph.eval(fn_program.py_object)) - assert np.allclose(res, run_lda(X_np, y_np)) - return egraph.eval(fn_program.statements) + fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) + egraph.run(array_api_program_gen_schedule) + return egraph.eval(egraph.extract(fn_program.statements)) def trace_lda(egraph: EGraph): @@ -147,9 +143,10 @@ def test_source(self, snapshot_py, benchmark): assert benchmark(load_source, expr, egraph) == snapshot_py def test_source_optimized(self, snapshot_py, benchmark): - egraph = EGraph() - expr = trace_lda(egraph) - optimized_expr = egraph.simplify(expr, array_api_numba_schedule) + egraph = EGraph(save_egglog_string=True) + with egraph: + expr = trace_lda(egraph) + optimized_expr = egraph.simplify(expr, array_api_numba_schedule) assert benchmark(load_source, optimized_expr, egraph) == snapshot_py @pytest.mark.parametrize( @@ -163,5 +160,5 @@ def test_source_optimized(self, snapshot_py, benchmark): ) def test_execution(self, fn, benchmark): # warmup once for numba - assert np.allclose(res, fn(X_np, y_np)) + assert np.allclose(res_np, fn(X_np, y_np)) benchmark(fn, X_np, y_np) From c8cc71e0f2b655e4a5abe44496709e027fa96c3d Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 23 Oct 2024 10:32:30 -0400 Subject: [PATCH 2/5] Don't save egglog string in benchmark --- python/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 8300b698..74fdf5ba 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -143,7 +143,7 @@ def test_source(self, snapshot_py, benchmark): assert benchmark(load_source, expr, egraph) == snapshot_py def test_source_optimized(self, snapshot_py, benchmark): - egraph = EGraph(save_egglog_string=True) + egraph = EGraph() with egraph: expr = trace_lda(egraph) optimized_expr = egraph.simplify(expr, array_api_numba_schedule) From fd367974c3b98947814b94344a301261aa42c963 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 23 Oct 2024 15:08:31 -0400 Subject: [PATCH 3/5] Try putting in let binding --- python/tests/test_array_api.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 74fdf5ba..2199760f 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -104,9 +104,10 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: def load_source(expr, egraph: EGraph): - fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) - egraph.run(array_api_program_gen_schedule) - return egraph.eval(egraph.extract(fn_program.statements)) + with egraph: + fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) + egraph.run(array_api_program_gen_schedule) + return egraph.eval(egraph.extract(fn_program.statements)) def trace_lda(egraph: EGraph): From d43ba4be994378a10ee2fcfea316b47533da1448 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 23 Oct 2024 15:11:22 -0400 Subject: [PATCH 4/5] Remove a let --- python/tests/test_array_api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 2199760f..f127e7e9 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -145,9 +145,8 @@ def test_source(self, snapshot_py, benchmark): def test_source_optimized(self, snapshot_py, benchmark): egraph = EGraph() - with egraph: - expr = trace_lda(egraph) - optimized_expr = egraph.simplify(expr, array_api_numba_schedule) + expr = trace_lda(egraph) + optimized_expr = egraph.simplify(expr, array_api_numba_schedule) assert benchmark(load_source, optimized_expr, egraph) == snapshot_py @pytest.mark.parametrize( From eac26348095cda5f36580fe1dc00f4ab8822ffdd Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 23 Oct 2024 15:27:49 -0400 Subject: [PATCH 5/5] Improve extraction and error logic --- python/egglog/exp/array_api.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 35754af7..d494746b 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1538,21 +1538,23 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool: egraph.register(expr) egraph.run(array_api_schedule) try: - extracted = egraph.extract(expr) - except EggSmolError as exc: - raise ValueError(f"Cannot extract {expr}") from exc - try: - return egraph.eval(prim_expr) + extracted = egraph.extract(prim_expr) except EggSmolError as exc: + # Try giving some context, by showing the smallest version of the larger expression + try: + expr_extracted = egraph.extract(expr) + except EggSmolError as inner_exc: + raise ValueError(f"Cannot simplify {expr}") from inner_exc egraph.display(n_inline_leaves=1, split_primitive_outputs=True) - msg = f"Cannot simplify to primitive {extracted}" - - # string = ( - # egraph.as_egglog_string - # + "\n" - # + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__)) - # ) - # # save to "tmp.egg" - # with open("tmp.egg", "w") as f: - # f.write(string) + msg = f"Cannot simplify to primitive {expr_extracted}" raise ValueError(msg) from exc + return egraph.eval(extracted) + + # string = ( + # egraph.as_egglog_string + # + "\n" + # + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__)) + # ) + # # save to "tmp.egg" + # with open("tmp.egg", "w") as f: + # f.write(string)