19
19
from egglog .exp .array_api_loopnest import *
20
20
from egglog .exp .array_api_numba import array_api_numba_schedule
21
21
from egglog .exp .array_api_program_gen import *
22
- from egglog .exp .program_gen import Program
22
+ from egglog .exp .program_gen import EvalProgram , Program
23
23
24
24
some_shape = constant ("some_shape" , TupleInt )
25
25
some_dtype = constant ("some_dtype" , DType )
@@ -327,33 +327,19 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
327
327
return globals [var ]
328
328
329
329
330
- def load_source (fn_program : EvalProgram , egraph : EGraph ):
331
- egraph .register (fn_program )
332
- egraph .run (array_api_program_gen_schedule )
333
- # dp the needed pieces in here for benchmarking
334
- try :
335
- return egraph .extract (fn_program .as_py_object ).eval ()
336
- except Exception as err :
337
- err .add_note (f"Failed to compile the program into a string: \n \n { egraph .extract (fn_program )} " )
338
- egraph .display (split_primitive_outputs = True , n_inline_leaves = 3 , split_functions = [Program ])
339
- raise
340
-
341
-
342
- def lda (X , y ):
330
+ def lda (X : NDArray , y : NDArray ):
343
331
assume_dtype (X , X_np .dtype )
344
332
assume_shape (X , X_np .shape )
345
333
assume_isfinite (X )
346
334
347
335
assume_dtype (y , y_np .dtype )
348
336
assume_shape (y , y_np .shape )
349
- assume_value_one_of (y , tuple (map (int , np .unique (y_np )))) # type: ignore[arg-type]
337
+ assume_value_one_of (y , tuple (map (int , np .unique (y_np ))))
350
338
return run_lda (X , y )
351
339
352
340
353
- def simplify_lda (egraph : EGraph , expr : NDArray ) -> NDArray :
354
- egraph .register (expr )
355
- egraph .run (array_api_numba_schedule )
356
- return egraph .extract (expr )
341
+ def lda_filled ():
342
+ return lda (NDArray .var ("X" ), NDArray .var ("y" ))
357
343
358
344
359
345
@pytest .mark .parametrize (
@@ -398,21 +384,20 @@ class TestLDA:
398
384
"""
399
385
400
386
def test_trace (self , snapshot_py , benchmark ):
401
- X = NDArray .var ("X" )
402
- y = NDArray .var ("y" )
403
- with EGraph ().set_current ():
404
- X_r2 = benchmark (lda , X , y )
387
+ @benchmark
388
+ def X_r2 ():
389
+ with EGraph ().set_current ():
390
+ return lda_filled ()
391
+
405
392
res = str (X_r2 )
406
- print (res )
407
393
assert res == snapshot_py
408
394
409
395
def test_optimize (self , snapshot_py , benchmark ):
410
396
egraph = EGraph ()
411
- X = NDArray .var ("X" )
412
- y = NDArray .var ("y" )
413
397
with egraph .set_current ():
414
- expr = lda (X , y )
415
- simplified = benchmark (simplify_lda , egraph , expr )
398
+ expr = lda_filled ()
399
+ simplified = benchmark (egraph .simplify , expr , array_api_numba_schedule )
400
+
416
401
assert str (simplified ) == snapshot_py
417
402
418
403
# @pytest.mark.xfail(reason="Original source is not working")
@@ -423,18 +408,17 @@ def test_optimize(self, snapshot_py, benchmark):
423
408
424
409
def test_source_optimized (self , snapshot_py , benchmark ):
425
410
egraph = EGraph ()
426
- X = NDArray .var ("X" )
427
- y = NDArray .var ("y" )
428
411
with egraph .set_current ():
429
- expr = lda (X , y )
430
- optimized_expr = simplify_lda (egraph , expr )
431
- egraph = EGraph ()
432
- fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
433
- py_object = benchmark (load_source , fn_program , egraph )
412
+ expr = lda_filled ()
413
+ optimized_expr = egraph .simplify (expr , array_api_numba_schedule )
414
+
415
+ @benchmark
416
+ def py_object ():
417
+ fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
418
+ return try_evaling (array_api_program_gen_schedule , fn_program , fn_program .as_py_object )
419
+
434
420
assert np .allclose (py_object (X_np , y_np ), run_lda (X_np , y_np ))
435
- with egraph .set_current ():
436
- fn_object = cast (FunctionType , fn_program .as_py_object .eval ())
437
- assert inspect .getsource (fn_object ) == snapshot_py
421
+ assert inspect .getsource (py_object ) == snapshot_py
438
422
439
423
@pytest .mark .parametrize (
440
424
"fn_thunk" ,
0 commit comments