11# mypy: disable-error-code="empty-body"
2- import ast
32import inspect
43from collections .abc import Callable
54from itertools import product
65from pathlib import Path
76from types import FunctionType
8- from typing import Any
97
108import numba
119import pytest
@@ -306,42 +304,6 @@ def run_lda(x, y):
306304X_np , y_np = (iris .data , iris .target )
307305
308306
309- def _load_py_snapshot (fn : Callable , var : str | None = None ) -> Any :
310- """
311- Load a python snapshot, evaling the code, and returning the `var` defined in it.
312-
313- If no var is provided, then return the last expression.
314- """
315- path = Path (__file__ ).parent / "__snapshots__" / "test_array_api" / f"TestLDA.{ fn .__name__ } .py"
316- contents = path .read_text ()
317-
318- contents = "import numpy as np\n from egglog.exp.array_api import *\n " + contents
319- globals : dict [str , Any ] = {}
320- if var is None :
321- # exec once as a full statement
322- exec (contents , globals )
323- # Eval the last statement
324- last_expr = ast .unparse (ast .parse (contents ).body [- 1 ])
325- return eval (last_expr , globals )
326- exec (contents , globals )
327- return globals [var ]
328-
329-
330- def lda (X : NDArray , y : NDArray ):
331- assume_dtype (X , X_np .dtype )
332- assume_shape (X , X_np .shape )
333- assume_isfinite (X )
334-
335- assume_dtype (y , y_np .dtype )
336- assume_shape (y , y_np .shape )
337- assume_value_one_of (y , tuple (map (int , np .unique (y_np ))))
338- return run_lda (X , y )
339-
340-
341- def lda_filled ():
342- return lda (NDArray .var ("X" ), NDArray .var ("y" ))
343-
344-
345307@pytest .mark .parametrize (
346308 "program" ,
347309 [
@@ -364,80 +326,46 @@ def test_program_compile(program: Program, snapshot_py):
364326 assert "\n " .join ([* statements .split ("\n " ), expr ]) == snapshot_py (name = "code" )
365327
366328
329+ def lda (X : NDArray , y : NDArray ):
330+ assume_dtype (X , X_np .dtype )
331+ assume_shape (X , X_np .shape )
332+ assume_isfinite (X )
333+
334+ assume_dtype (y , y_np .dtype )
335+ assume_shape (y , y_np .shape )
336+ assume_value_one_of (y , tuple (map (int , np .unique (y_np ))))
337+ return run_lda (X , y )
338+
339+
367340@pytest .mark .parametrize (
368341 "program" ,
369342 [
370343 pytest .param (lambda x , y : x + y , id = "add" ),
371344 pytest .param (lambda x , y : x [(x .shape + TupleInt .from_vec ((1 , 2 )))[100 ]], id = "tuple" ),
345+ pytest .param (lda , id = "lda" ),
372346 ],
373347)
374- def test_jit (program , snapshot_py ):
375- jitted = jit (program )
348+ def test_jit (program , snapshot_py , benchmark ):
349+ jitted = benchmark (jit , program )
350+ assert str (jitted .initial_expr ) == snapshot_py (name = "initial_expr" )
376351 assert str (jitted .expr ) == snapshot_py (name = "expr" )
377352 assert inspect .getsource (jitted ) == snapshot_py (name = "code" )
378353
379354
380- @pytest .mark .benchmark (min_rounds = 3 )
381- class TestLDA :
382- """
383- Incrementally benchmark each part of the LDA to see how long it takes to run.
384- """
385-
386- def test_trace (self , snapshot_py , benchmark ):
387- @benchmark
388- def X_r2 ():
389- with EGraph ().set_current ():
390- return lda_filled ()
391-
392- res = str (X_r2 )
393- assert res == snapshot_py
394-
395- def test_optimize (self , snapshot_py , benchmark ):
396- egraph = EGraph ()
397- with egraph .set_current ():
398- expr = lda_filled ()
399- simplified = benchmark (egraph .simplify , expr , array_api_numba_schedule )
400-
401- assert str (simplified ) == snapshot_py
402-
403- # @pytest.mark.xfail(reason="Original source is not working")
404- # def test_source(self, snapshot_py, benchmark):
405- # egraph = EGraph()
406- # expr = trace_lda(egraph)
407- # assert benchmark(load_source, expr, egraph) == snapshot_py
408-
409- def test_source_optimized (self , snapshot_py , benchmark ):
410- egraph = EGraph ()
411- with egraph .set_current ():
412- expr = lda_filled ()
413- optimized_expr = egraph .simplify (expr , array_api_numba_schedule )
414-
415- @benchmark
416- def py_object ():
417- fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
418- return try_evaling (array_api_program_gen_schedule , fn_program , fn_program .as_py_object )
419-
420- assert np .allclose (py_object (X_np , y_np ), run_lda (X_np , y_np ))
421- assert inspect .getsource (py_object ) == snapshot_py
422-
423- @pytest .mark .parametrize (
424- "fn_thunk" ,
425- [
426- pytest .param (lambda : LinearDiscriminantAnalysis (n_components = 2 ).fit_transform , id = "base" ),
427- pytest .param (lambda : run_lda , id = "array_api" ),
428- pytest .param (lambda : _load_py_snapshot (TestLDA .test_source_optimized , "__fn" ), id = "array_api-optimized" ),
429- pytest .param (
430- lambda : numba .njit (_load_py_snapshot (TestLDA .test_source_optimized , "__fn" )),
431- id = "array_api-optimized-numba" ,
432- ),
433- pytest .param (lambda : jit (lda ), id = "array_api-jit" ),
434- ],
435- )
436- def test_execution (self , fn_thunk , benchmark ):
437- fn = fn_thunk ()
438- # warmup once for numba
439- assert np .allclose (run_lda (X_np , y_np ), fn (X_np , y_np ), rtol = 1e-03 )
440- benchmark (fn , X_np , y_np )
355+ @pytest .mark .parametrize (
356+ "fn_thunk" ,
357+ [
358+ pytest .param (lambda : LinearDiscriminantAnalysis (n_components = 2 ).fit_transform , id = "base" ),
359+ pytest .param (lambda : run_lda , id = "array_api" ),
360+ pytest .param (lambda : jit (lda ), id = "array_api-optimized" ),
361+ pytest .param (lambda : numba .njit (jit (lda )), id = "array_api-optimized-numba" ),
362+ ],
363+ )
364+ def test_run_lda (fn_thunk , benchmark ):
365+ fn = fn_thunk ()
366+ # warmup once for numba
367+ assert np .allclose (run_lda (X_np , y_np ), fn (X_np , y_np ), rtol = 1e-03 )
368+ benchmark (fn , X_np , y_np )
441369
442370
443371# if calling as script, print out egglog source for test
0 commit comments