Skip to content

Commit d282494

Browse files
Merge pull request #267 from egraphs-good/fix-benchmark
Simplify LDA benchmarks even more
2 parents 4fd253f + 6856926 commit d282494

File tree

8 files changed

+33
-102
lines changed

8 files changed

+33
-102
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
with:
3737
python-version: ${{ matrix.py }}
3838
- run: uv sync --extra test --locked
39-
- run: uv run pytest --benchmark-disable -vvv
39+
- run: uv run pytest --benchmark-disable -vvv --durations=10
4040

4141
mypy:
4242
runs-on: ubuntu-latest

python/egglog/exp/array_api_jit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ def jit(fn: X) -> X:
2323

2424
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
2525
fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
26+
fn.initial_expr = res # type: ignore[attr-defined]
2627
fn.expr = res_optimized # type: ignore[attr-defined]
2728
return cast(X, fn)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
NDArray.var("x") + NDArray.var("y")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt.from_vec(Vec[Int](Int(1), Int(2))))[Int(100)])]

python/tests/test_array_api.py

Lines changed: 29 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# mypy: disable-error-code="empty-body"
2-
import ast
32
import inspect
43
from collections.abc import Callable
54
from itertools import product
65
from pathlib import Path
76
from types import FunctionType
8-
from typing import Any
97

108
import numba
119
import pytest
@@ -306,42 +304,6 @@ def run_lda(x, y):
306304
X_np, y_np = (iris.data, iris.target)
307305

308306

309-
def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
310-
"""
311-
Load a python snapshot, evaling the code, and returning the `var` defined in it.
312-
313-
If no var is provided, then return the last expression.
314-
"""
315-
path = Path(__file__).parent / "__snapshots__" / "test_array_api" / f"TestLDA.{fn.__name__}.py"
316-
contents = path.read_text()
317-
318-
contents = "import numpy as np\nfrom egglog.exp.array_api import *\n" + contents
319-
globals: dict[str, Any] = {}
320-
if var is None:
321-
# exec once as a full statement
322-
exec(contents, globals)
323-
# Eval the last statement
324-
last_expr = ast.unparse(ast.parse(contents).body[-1])
325-
return eval(last_expr, globals)
326-
exec(contents, globals)
327-
return globals[var]
328-
329-
330-
def lda(X: NDArray, y: NDArray):
331-
assume_dtype(X, X_np.dtype)
332-
assume_shape(X, X_np.shape)
333-
assume_isfinite(X)
334-
335-
assume_dtype(y, y_np.dtype)
336-
assume_shape(y, y_np.shape)
337-
assume_value_one_of(y, tuple(map(int, np.unique(y_np))))
338-
return run_lda(X, y)
339-
340-
341-
def lda_filled():
342-
return lda(NDArray.var("X"), NDArray.var("y"))
343-
344-
345307
@pytest.mark.parametrize(
346308
"program",
347309
[
@@ -364,80 +326,46 @@ def test_program_compile(program: Program, snapshot_py):
364326
assert "\n".join([*statements.split("\n"), expr]) == snapshot_py(name="code")
365327

366328

329+
def lda(X: NDArray, y: NDArray):
330+
assume_dtype(X, X_np.dtype)
331+
assume_shape(X, X_np.shape)
332+
assume_isfinite(X)
333+
334+
assume_dtype(y, y_np.dtype)
335+
assume_shape(y, y_np.shape)
336+
assume_value_one_of(y, tuple(map(int, np.unique(y_np))))
337+
return run_lda(X, y)
338+
339+
367340
@pytest.mark.parametrize(
368341
"program",
369342
[
370343
pytest.param(lambda x, y: x + y, id="add"),
371344
pytest.param(lambda x, y: x[(x.shape + TupleInt.from_vec((1, 2)))[100]], id="tuple"),
345+
pytest.param(lda, id="lda"),
372346
],
373347
)
374-
def test_jit(program, snapshot_py):
375-
jitted = jit(program)
348+
def test_jit(program, snapshot_py, benchmark):
349+
jitted = benchmark(jit, program)
350+
assert str(jitted.initial_expr) == snapshot_py(name="initial_expr")
376351
assert str(jitted.expr) == snapshot_py(name="expr")
377352
assert inspect.getsource(jitted) == snapshot_py(name="code")
378353

379354

380-
@pytest.mark.benchmark(min_rounds=3)
381-
class TestLDA:
382-
"""
383-
Incrementally benchmark each part of the LDA to see how long it takes to run.
384-
"""
385-
386-
def test_trace(self, snapshot_py, benchmark):
387-
@benchmark
388-
def X_r2():
389-
with EGraph().set_current():
390-
return lda_filled()
391-
392-
res = str(X_r2)
393-
assert res == snapshot_py
394-
395-
def test_optimize(self, snapshot_py, benchmark):
396-
egraph = EGraph()
397-
with egraph.set_current():
398-
expr = lda_filled()
399-
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)
400-
401-
assert str(simplified) == snapshot_py
402-
403-
# @pytest.mark.xfail(reason="Original source is not working")
404-
# def test_source(self, snapshot_py, benchmark):
405-
# egraph = EGraph()
406-
# expr = trace_lda(egraph)
407-
# assert benchmark(load_source, expr, egraph) == snapshot_py
408-
409-
def test_source_optimized(self, snapshot_py, benchmark):
410-
egraph = EGraph()
411-
with egraph.set_current():
412-
expr = lda_filled()
413-
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
414-
415-
@benchmark
416-
def py_object():
417-
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
418-
return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
419-
420-
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
421-
assert inspect.getsource(py_object) == snapshot_py
422-
423-
@pytest.mark.parametrize(
424-
"fn_thunk",
425-
[
426-
pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
427-
pytest.param(lambda: run_lda, id="array_api"),
428-
pytest.param(lambda: _load_py_snapshot(TestLDA.test_source_optimized, "__fn"), id="array_api-optimized"),
429-
pytest.param(
430-
lambda: numba.njit(_load_py_snapshot(TestLDA.test_source_optimized, "__fn")),
431-
id="array_api-optimized-numba",
432-
),
433-
pytest.param(lambda: jit(lda), id="array_api-jit"),
434-
],
435-
)
436-
def test_execution(self, fn_thunk, benchmark):
437-
fn = fn_thunk()
438-
# warmup once for numba
439-
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03)
440-
benchmark(fn, X_np, y_np)
355+
@pytest.mark.parametrize(
356+
"fn_thunk",
357+
[
358+
pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
359+
pytest.param(lambda: run_lda, id="array_api"),
360+
pytest.param(lambda: jit(lda), id="array_api-optimized"),
361+
pytest.param(lambda: numba.njit(jit(lda)), id="array_api-optimized-numba"),
362+
],
363+
)
364+
def test_run_lda(fn_thunk, benchmark):
365+
fn = fn_thunk()
366+
# warmup once for numba
367+
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03)
368+
benchmark(fn, X_np, y_np)
441369

442370

443371
# if calling as script, print out egglog source for test

0 commit comments

Comments
 (0)