1
1
# mypy: disable-error-code="empty-body"
2
- import ast
3
2
import inspect
4
3
from collections .abc import Callable
5
4
from itertools import product
6
5
from pathlib import Path
7
6
from types import FunctionType
8
- from typing import Any
9
7
10
8
import numba
11
9
import pytest
@@ -306,42 +304,6 @@ def run_lda(x, y):
306
304
X_np , y_np = (iris .data , iris .target )
307
305
308
306
309
- def _load_py_snapshot (fn : Callable , var : str | None = None ) -> Any :
310
- """
311
- Load a python snapshot, evaling the code, and returning the `var` defined in it.
312
-
313
- If no var is provided, then return the last expression.
314
- """
315
- path = Path (__file__ ).parent / "__snapshots__" / "test_array_api" / f"TestLDA.{ fn .__name__ } .py"
316
- contents = path .read_text ()
317
-
318
- contents = "import numpy as np\n from egglog.exp.array_api import *\n " + contents
319
- globals : dict [str , Any ] = {}
320
- if var is None :
321
- # exec once as a full statement
322
- exec (contents , globals )
323
- # Eval the last statement
324
- last_expr = ast .unparse (ast .parse (contents ).body [- 1 ])
325
- return eval (last_expr , globals )
326
- exec (contents , globals )
327
- return globals [var ]
328
-
329
-
330
- def lda (X : NDArray , y : NDArray ):
331
- assume_dtype (X , X_np .dtype )
332
- assume_shape (X , X_np .shape )
333
- assume_isfinite (X )
334
-
335
- assume_dtype (y , y_np .dtype )
336
- assume_shape (y , y_np .shape )
337
- assume_value_one_of (y , tuple (map (int , np .unique (y_np ))))
338
- return run_lda (X , y )
339
-
340
-
341
- def lda_filled ():
342
- return lda (NDArray .var ("X" ), NDArray .var ("y" ))
343
-
344
-
345
307
@pytest .mark .parametrize (
346
308
"program" ,
347
309
[
@@ -364,80 +326,46 @@ def test_program_compile(program: Program, snapshot_py):
364
326
assert "\n " .join ([* statements .split ("\n " ), expr ]) == snapshot_py (name = "code" )
365
327
366
328
329
+ def lda (X : NDArray , y : NDArray ):
330
+ assume_dtype (X , X_np .dtype )
331
+ assume_shape (X , X_np .shape )
332
+ assume_isfinite (X )
333
+
334
+ assume_dtype (y , y_np .dtype )
335
+ assume_shape (y , y_np .shape )
336
+ assume_value_one_of (y , tuple (map (int , np .unique (y_np ))))
337
+ return run_lda (X , y )
338
+
339
+
367
340
@pytest .mark .parametrize (
368
341
"program" ,
369
342
[
370
343
pytest .param (lambda x , y : x + y , id = "add" ),
371
344
pytest .param (lambda x , y : x [(x .shape + TupleInt .from_vec ((1 , 2 )))[100 ]], id = "tuple" ),
345
+ pytest .param (lda , id = "lda" ),
372
346
],
373
347
)
374
- def test_jit (program , snapshot_py ):
375
- jitted = jit (program )
348
+ def test_jit (program , snapshot_py , benchmark ):
349
+ jitted = benchmark (jit , program )
350
+ assert str (jitted .initial_expr ) == snapshot_py (name = "initial_expr" )
376
351
assert str (jitted .expr ) == snapshot_py (name = "expr" )
377
352
assert inspect .getsource (jitted ) == snapshot_py (name = "code" )
378
353
379
354
380
- @pytest .mark .benchmark (min_rounds = 3 )
381
- class TestLDA :
382
- """
383
- Incrementally benchmark each part of the LDA to see how long it takes to run.
384
- """
385
-
386
- def test_trace (self , snapshot_py , benchmark ):
387
- @benchmark
388
- def X_r2 ():
389
- with EGraph ().set_current ():
390
- return lda_filled ()
391
-
392
- res = str (X_r2 )
393
- assert res == snapshot_py
394
-
395
- def test_optimize (self , snapshot_py , benchmark ):
396
- egraph = EGraph ()
397
- with egraph .set_current ():
398
- expr = lda_filled ()
399
- simplified = benchmark (egraph .simplify , expr , array_api_numba_schedule )
400
-
401
- assert str (simplified ) == snapshot_py
402
-
403
- # @pytest.mark.xfail(reason="Original source is not working")
404
- # def test_source(self, snapshot_py, benchmark):
405
- # egraph = EGraph()
406
- # expr = trace_lda(egraph)
407
- # assert benchmark(load_source, expr, egraph) == snapshot_py
408
-
409
- def test_source_optimized (self , snapshot_py , benchmark ):
410
- egraph = EGraph ()
411
- with egraph .set_current ():
412
- expr = lda_filled ()
413
- optimized_expr = egraph .simplify (expr , array_api_numba_schedule )
414
-
415
- @benchmark
416
- def py_object ():
417
- fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
418
- return try_evaling (array_api_program_gen_schedule , fn_program , fn_program .as_py_object )
419
-
420
- assert np .allclose (py_object (X_np , y_np ), run_lda (X_np , y_np ))
421
- assert inspect .getsource (py_object ) == snapshot_py
422
-
423
- @pytest .mark .parametrize (
424
- "fn_thunk" ,
425
- [
426
- pytest .param (lambda : LinearDiscriminantAnalysis (n_components = 2 ).fit_transform , id = "base" ),
427
- pytest .param (lambda : run_lda , id = "array_api" ),
428
- pytest .param (lambda : _load_py_snapshot (TestLDA .test_source_optimized , "__fn" ), id = "array_api-optimized" ),
429
- pytest .param (
430
- lambda : numba .njit (_load_py_snapshot (TestLDA .test_source_optimized , "__fn" )),
431
- id = "array_api-optimized-numba" ,
432
- ),
433
- pytest .param (lambda : jit (lda ), id = "array_api-jit" ),
434
- ],
435
- )
436
- def test_execution (self , fn_thunk , benchmark ):
437
- fn = fn_thunk ()
438
- # warmup once for numba
439
- assert np .allclose (run_lda (X_np , y_np ), fn (X_np , y_np ), rtol = 1e-03 )
440
- benchmark (fn , X_np , y_np )
355
+ @pytest .mark .parametrize (
356
+ "fn_thunk" ,
357
+ [
358
+ pytest .param (lambda : LinearDiscriminantAnalysis (n_components = 2 ).fit_transform , id = "base" ),
359
+ pytest .param (lambda : run_lda , id = "array_api" ),
360
+ pytest .param (lambda : jit (lda ), id = "array_api-optimized" ),
361
+ pytest .param (lambda : numba .njit (jit (lda )), id = "array_api-optimized-numba" ),
362
+ ],
363
+ )
364
+ def test_run_lda (fn_thunk , benchmark ):
365
+ fn = fn_thunk ()
366
+ # warmup once for numba
367
+ assert np .allclose (run_lda (X_np , y_np ), fn (X_np , y_np ), rtol = 1e-03 )
368
+ benchmark (fn , X_np , y_np )
441
369
442
370
443
371
# if calling as script, print out egglog source for test
0 commit comments