88from sklearn import config_context , datasets
99from sklearn .discriminant_analysis import LinearDiscriminantAnalysis
1010
11+ from egglog .egraph import set_current_ruleset
1112from egglog .exp .array_api import *
1213from egglog .exp .array_api_jit import jit
14+ from egglog .exp .array_api_loopnest import *
1315from egglog .exp .array_api_numba import array_api_numba_schedule
1416from egglog .exp .array_api_program_gen import *
1517
@@ -68,6 +70,41 @@ def test_reshape_vec_noop():
6870 egraph .check (eq (res ).to (x ))
6971
7072
73+ def test_filter ():
74+ with set_current_ruleset (array_api_ruleset ):
75+ x = TupleInt .range (5 ).filter (lambda i : i < 2 ).length ()
76+ check_eq (x , Int (2 ), array_api_schedule )
77+
78+
79+ @function (ruleset = array_api_ruleset , subsume = True )
80+ def linalg_norm (X : NDArray , axis : TupleIntLike ) -> NDArray :
81+ # peel off the outer shape for result array
82+ outshape = ShapeAPI (X .shape ).deselect (axis ).to_tuple ()
83+ # get only the inner shape for reduction
84+ reduce_axis = ShapeAPI (X .shape ).select (axis ).to_tuple ()
85+
86+ return NDArray (
87+ outshape ,
88+ X .dtype ,
89+ lambda k : sqrt (
90+ LoopNestAPI .from_tuple (reduce_axis )
91+ .unwrap ()
92+ .fold (lambda carry , i : carry + real (conj (x := X [i + k ]) * x ), init = 0.0 )
93+ ).to_value (),
94+ )
95+
96+
97+ class TestLoopNest :
98+ def test_shape (self ):
99+ X = NDArray .var ("X" )
100+ assume_shape (X , (3 , 2 , 3 , 4 ))
101+ val = linalg_norm (X , (0 , 1 ))
102+
103+ check_eq (val .shape .length (), Int (2 ), array_api_schedule )
104+ check_eq (val .shape [0 ], Int (3 ), array_api_schedule )
105+ check_eq (val .shape [1 ], Int (4 ), array_api_schedule )
106+
107+
71108# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
72109# The next step will load the old snapshot and run their test on it.
73110
@@ -80,7 +117,6 @@ def run_lda(x, y):
80117
81118iris = datasets .load_iris ()
82119X_np , y_np = (iris .data , iris .target )
83- res_np = run_lda (X_np , y_np )
84120
85121
86122def _load_py_snapshot (fn : Callable , var : str | None = None ) -> Any :
@@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
165201 optimized_expr = simplify_lda (egraph , expr )
166202 fn_program = ndarray_function_two (optimized_expr , NDArray .var ("X" ), NDArray .var ("y" ))
167203 py_object = benchmark (load_source , fn_program , egraph )
168- assert np .allclose (py_object (X_np , y_np ), res_np )
204+ assert np .allclose (py_object (X_np , y_np ), run_lda ( X_np , y_np ) )
169205 assert egraph .eval (fn_program .statements ) == snapshot_py
170206
171207 @pytest .mark .parametrize (
@@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
180216 )
181217 def test_execution (self , fn , benchmark ):
182218 # warmup once for numba
183- assert np .allclose (res_np , fn (X_np , y_np ))
219+ assert np .allclose (run_lda ( X_np , y_np ) , fn (X_np , y_np ))
184220 benchmark (fn , X_np , y_np )
185221
186222
0 commit comments