Skip to content

Commit 4fd253f

Browse files
Merge pull request #266 from egraphs-good/fix-benchmark
More accurate local benchmarking
2 parents 3489a6b + dd6b6f6 commit 4fd253f

File tree

1 file changed

+22
-38
lines changed

1 file changed

+22
-38
lines changed

python/tests/test_array_api.py

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from egglog.exp.array_api_loopnest import *
2020
from egglog.exp.array_api_numba import array_api_numba_schedule
2121
from egglog.exp.array_api_program_gen import *
22-
from egglog.exp.program_gen import Program
22+
from egglog.exp.program_gen import EvalProgram, Program
2323

2424
some_shape = constant("some_shape", TupleInt)
2525
some_dtype = constant("some_dtype", DType)
@@ -327,33 +327,19 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
327327
return globals[var]
328328

329329

330-
def load_source(fn_program: EvalProgram, egraph: EGraph):
331-
egraph.register(fn_program)
332-
egraph.run(array_api_program_gen_schedule)
333-
# dp the needed pieces in here for benchmarking
334-
try:
335-
return egraph.extract(fn_program.as_py_object).eval()
336-
except Exception as err:
337-
err.add_note(f"Failed to compile the program into a string: \n\n{egraph.extract(fn_program)}")
338-
egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[Program])
339-
raise
340-
341-
342-
def lda(X, y):
330+
def lda(X: NDArray, y: NDArray):
343331
assume_dtype(X, X_np.dtype)
344332
assume_shape(X, X_np.shape)
345333
assume_isfinite(X)
346334

347335
assume_dtype(y, y_np.dtype)
348336
assume_shape(y, y_np.shape)
349-
assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]
337+
assume_value_one_of(y, tuple(map(int, np.unique(y_np))))
350338
return run_lda(X, y)
351339

352340

353-
def simplify_lda(egraph: EGraph, expr: NDArray) -> NDArray:
354-
egraph.register(expr)
355-
egraph.run(array_api_numba_schedule)
356-
return egraph.extract(expr)
341+
def lda_filled():
342+
return lda(NDArray.var("X"), NDArray.var("y"))
357343

358344

359345
@pytest.mark.parametrize(
@@ -398,21 +384,20 @@ class TestLDA:
398384
"""
399385

400386
def test_trace(self, snapshot_py, benchmark):
401-
X = NDArray.var("X")
402-
y = NDArray.var("y")
403-
with EGraph().set_current():
404-
X_r2 = benchmark(lda, X, y)
387+
@benchmark
388+
def X_r2():
389+
with EGraph().set_current():
390+
return lda_filled()
391+
405392
res = str(X_r2)
406-
print(res)
407393
assert res == snapshot_py
408394

409395
def test_optimize(self, snapshot_py, benchmark):
410396
egraph = EGraph()
411-
X = NDArray.var("X")
412-
y = NDArray.var("y")
413397
with egraph.set_current():
414-
expr = lda(X, y)
415-
simplified = benchmark(simplify_lda, egraph, expr)
398+
expr = lda_filled()
399+
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)
400+
416401
assert str(simplified) == snapshot_py
417402

418403
# @pytest.mark.xfail(reason="Original source is not working")
@@ -423,18 +408,17 @@ def test_optimize(self, snapshot_py, benchmark):
423408

424409
def test_source_optimized(self, snapshot_py, benchmark):
425410
egraph = EGraph()
426-
X = NDArray.var("X")
427-
y = NDArray.var("y")
428411
with egraph.set_current():
429-
expr = lda(X, y)
430-
optimized_expr = simplify_lda(egraph, expr)
431-
egraph = EGraph()
432-
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
433-
py_object = benchmark(load_source, fn_program, egraph)
412+
expr = lda_filled()
413+
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
414+
415+
@benchmark
416+
def py_object():
417+
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
418+
return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
419+
434420
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
435-
with egraph.set_current():
436-
fn_object = cast(FunctionType, fn_program.as_py_object.eval())
437-
assert inspect.getsource(fn_object) == snapshot_py
421+
assert inspect.getsource(py_object) == snapshot_py
438422

439423
@pytest.mark.parametrize(
440424
"fn_thunk",

0 commit comments

Comments
 (0)