Skip to content

Commit ee30c2b

Browse files
Simplify LDA benchmarks even more
Previously we benchmarked the different stages of compilation separately. That was hard to manage, because it was difficult to be sure every stage was captured properly. This PR changes the benchmarks to measure the whole compilation as one unit; we can rely on profiling to determine which parts slowed down.
1 parent dd6b6f6 commit ee30c2b

File tree

7 files changed

+32
-101
lines changed

7 files changed

+32
-101
lines changed

python/egglog/exp/array_api_jit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ def jit(fn: X) -> X:
2323

2424
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
2525
fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
26+
fn.initial_expr = res # type: ignore[attr-defined]
2627
fn.expr = res_optimized # type: ignore[attr-defined]
2728
return cast(X, fn)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
NDArray.var("x") + NDArray.var("y")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt.from_vec(Vec[Int](Int(1), Int(2))))[Int(100)])]

python/tests/test_array_api.py

Lines changed: 29 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# mypy: disable-error-code="empty-body"
2-
import ast
32
import inspect
43
from collections.abc import Callable
54
from itertools import product
65
from pathlib import Path
76
from types import FunctionType
8-
from typing import Any
97

108
import numba
119
import pytest
@@ -306,42 +304,6 @@ def run_lda(x, y):
306304
X_np, y_np = (iris.data, iris.target)
307305

308306

309-
def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
310-
"""
311-
Load a python snapshot, evaling the code, and returning the `var` defined in it.
312-
313-
If no var is provided, then return the last expression.
314-
"""
315-
path = Path(__file__).parent / "__snapshots__" / "test_array_api" / f"TestLDA.{fn.__name__}.py"
316-
contents = path.read_text()
317-
318-
contents = "import numpy as np\nfrom egglog.exp.array_api import *\n" + contents
319-
globals: dict[str, Any] = {}
320-
if var is None:
321-
# exec once as a full statement
322-
exec(contents, globals)
323-
# Eval the last statement
324-
last_expr = ast.unparse(ast.parse(contents).body[-1])
325-
return eval(last_expr, globals)
326-
exec(contents, globals)
327-
return globals[var]
328-
329-
330-
def lda(X: NDArray, y: NDArray):
331-
assume_dtype(X, X_np.dtype)
332-
assume_shape(X, X_np.shape)
333-
assume_isfinite(X)
334-
335-
assume_dtype(y, y_np.dtype)
336-
assume_shape(y, y_np.shape)
337-
assume_value_one_of(y, tuple(map(int, np.unique(y_np))))
338-
return run_lda(X, y)
339-
340-
341-
def lda_filled():
342-
return lda(NDArray.var("X"), NDArray.var("y"))
343-
344-
345307
@pytest.mark.parametrize(
346308
"program",
347309
[
@@ -364,80 +326,46 @@ def test_program_compile(program: Program, snapshot_py):
364326
assert "\n".join([*statements.split("\n"), expr]) == snapshot_py(name="code")
365327

366328

329+
def lda(X: NDArray, y: NDArray):
330+
assume_dtype(X, X_np.dtype)
331+
assume_shape(X, X_np.shape)
332+
assume_isfinite(X)
333+
334+
assume_dtype(y, y_np.dtype)
335+
assume_shape(y, y_np.shape)
336+
assume_value_one_of(y, tuple(map(int, np.unique(y_np))))
337+
return run_lda(X, y)
338+
339+
367340
@pytest.mark.parametrize(
368341
"program",
369342
[
370343
pytest.param(lambda x, y: x + y, id="add"),
371344
pytest.param(lambda x, y: x[(x.shape + TupleInt.from_vec((1, 2)))[100]], id="tuple"),
345+
pytest.param(lda, id="lda"),
372346
],
373347
)
374-
def test_jit(program, snapshot_py):
375-
jitted = jit(program)
348+
def test_jit(program, snapshot_py, benchmark):
349+
jitted = benchmark(jit, program)
350+
assert str(jitted.initial_expr) == snapshot_py(name="initial_expr")
376351
assert str(jitted.expr) == snapshot_py(name="expr")
377352
assert inspect.getsource(jitted) == snapshot_py(name="code")
378353

379354

380-
@pytest.mark.benchmark(min_rounds=3)
381-
class TestLDA:
382-
"""
383-
Incrementally benchmark each part of the LDA to see how long it takes to run.
384-
"""
385-
386-
def test_trace(self, snapshot_py, benchmark):
387-
@benchmark
388-
def X_r2():
389-
with EGraph().set_current():
390-
return lda_filled()
391-
392-
res = str(X_r2)
393-
assert res == snapshot_py
394-
395-
def test_optimize(self, snapshot_py, benchmark):
396-
egraph = EGraph()
397-
with egraph.set_current():
398-
expr = lda_filled()
399-
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)
400-
401-
assert str(simplified) == snapshot_py
402-
403-
# @pytest.mark.xfail(reason="Original source is not working")
404-
# def test_source(self, snapshot_py, benchmark):
405-
# egraph = EGraph()
406-
# expr = trace_lda(egraph)
407-
# assert benchmark(load_source, expr, egraph) == snapshot_py
408-
409-
def test_source_optimized(self, snapshot_py, benchmark):
410-
egraph = EGraph()
411-
with egraph.set_current():
412-
expr = lda_filled()
413-
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
414-
415-
@benchmark
416-
def py_object():
417-
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
418-
return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
419-
420-
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
421-
assert inspect.getsource(py_object) == snapshot_py
422-
423-
@pytest.mark.parametrize(
424-
"fn_thunk",
425-
[
426-
pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
427-
pytest.param(lambda: run_lda, id="array_api"),
428-
pytest.param(lambda: _load_py_snapshot(TestLDA.test_source_optimized, "__fn"), id="array_api-optimized"),
429-
pytest.param(
430-
lambda: numba.njit(_load_py_snapshot(TestLDA.test_source_optimized, "__fn")),
431-
id="array_api-optimized-numba",
432-
),
433-
pytest.param(lambda: jit(lda), id="array_api-jit"),
434-
],
435-
)
436-
def test_execution(self, fn_thunk, benchmark):
437-
fn = fn_thunk()
438-
# warmup once for numba
439-
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03)
440-
benchmark(fn, X_np, y_np)
355+
@pytest.mark.parametrize(
356+
"fn_thunk",
357+
[
358+
pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
359+
pytest.param(lambda: run_lda, id="array_api"),
360+
pytest.param(lambda: jit(lda), id="array_api-optimized"),
361+
pytest.param(lambda: numba.njit(jit(lda)), id="array_api-optimized-numba"),
362+
],
363+
)
364+
def test_run_lda(fn_thunk, benchmark):
365+
fn = fn_thunk()
366+
# warmup once for numba
367+
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03)
368+
benchmark(fn, X_np, y_np)
441369

442370

443371
# if calling as script, print out egglog source for test

0 commit comments

Comments
 (0)