99from sklearn .discriminant_analysis import LinearDiscriminantAnalysis
1010
1111from egglog .exp .array_api import *
12+ from egglog .exp .array_api_jit import jit
1213from egglog .exp .array_api_numba import array_api_numba_schedule
1314from egglog .exp .array_api_program_gen import *
1415
@@ -103,51 +104,69 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
103104 return globals [var ]
104105
105106
106- def load_source (expr , egraph : EGraph ):
107- with egraph :
108- fn_program = egraph .let ( "fn_program" , ndarray_function_two ( expr , NDArray . var ( "X" ), NDArray . var ( "y" )) )
109- egraph . run ( array_api_program_gen_schedule )
110- return egraph .eval (egraph .extract (fn_program .statements ))
107+ def load_source (fn_program : EvalProgram , egraph : EGraph ):
108+ egraph . register ( fn_program )
109+ egraph .run ( array_api_program_gen_schedule )
110+ # do the needed pieces in here for benchmarking
111+ return egraph .eval (egraph .extract (fn_program .py_object ))
111112
112113
113- def trace_lda (egraph : EGraph ):
114- X_arr = NDArray .var ("X" )
115- assume_dtype (X_arr , X_np .dtype )
116- assume_shape (X_arr , X_np .shape )
117- assume_isfinite (X_arr )
114+ def lda (X , y ):
115+ assume_dtype (X , X_np .dtype )
116+ assume_shape (X , X_np .shape )
117+ assume_isfinite (X )
118118
119- y_arr = NDArray . var ( "y" )
120- assume_dtype ( y_arr , y_np .dtype )
121- assume_shape ( y_arr , y_np . shape )
122- assume_value_one_of ( y_arr , tuple ( map ( int , np . unique ( y_np )))) # type: ignore[arg-type]
119+ assume_dtype ( y , y_np . dtype )
120+ assume_shape ( y , y_np .shape )
121+ assume_value_one_of ( y , tuple ( map ( int , np . unique ( y_np )))) # type: ignore[arg-type]
122+ return run_lda ( X , y )
123123
124- with egraph :
125- return run_lda (X_arr , y_arr )
124+
125+ def simplify_lda (egraph : EGraph , expr : NDArray ) -> NDArray :
126+ egraph .register (expr )
127+ egraph .run (array_api_numba_schedule )
128+ return egraph .extract (expr )
126129
127130
128131@pytest .mark .benchmark (min_rounds = 3 )
129132class TestLDA :
133+ """
134+ Incrementally benchmark each part of the LDA to see how long it takes to run.
135+ """
136+
130137 def test_trace (self , snapshot_py , benchmark ):
131- X_r2 = benchmark (trace_lda , EGraph ())
138+ X = NDArray .var ("X" )
139+ y = NDArray .var ("y" )
140+ with EGraph ():
141+ X_r2 = benchmark (lda , X , y )
132142 assert str (X_r2 ) == snapshot_py
133143
134144 def test_optimize (self , snapshot_py , benchmark ):
135145 egraph = EGraph ()
136- expr = trace_lda (egraph )
137- simplified = benchmark (egraph .simplify , expr , array_api_numba_schedule )
146+ X = NDArray .var ("X" )
147+ y = NDArray .var ("y" )
148+ with egraph :
149+ expr = lda (X , y )
150+ simplified = benchmark (simplify_lda , egraph , expr )
138151 assert str (simplified ) == snapshot_py
139152
140- @pytest .mark .xfail (reason = "Original source is not working" )
141- def test_source (self , snapshot_py , benchmark ):
142- egraph = EGraph ()
143- expr = trace_lda (egraph )
144- assert benchmark (load_source , expr , egraph ) == snapshot_py
153+ # @pytest.mark.xfail(reason="Original source is not working")
154+ # def test_source(self, snapshot_py, benchmark):
155+ # egraph = EGraph()
156+ # expr = trace_lda(egraph)
157+ # assert benchmark(load_source, expr, egraph) == snapshot_py
145158
146159 def test_source_optimized (self , snapshot_py , benchmark ):
147160 egraph = EGraph ()
148- expr = trace_lda (egraph )
149- optimized_expr = egraph .simplify (expr , array_api_numba_schedule )
150- assert benchmark (load_source , optimized_expr , egraph ) == snapshot_py
161+ X = NDArray .var ("X" )
162+ y = NDArray .var ("y" )
163+ with egraph :
164+ expr = lda (X , y )
165+ optimized_expr = simplify_lda (egraph , expr )
166+ fn_program = ndarray_function_two (optimized_expr , X , y )
167+ py_object = benchmark (load_source , fn_program , egraph )
168+ assert np .allclose (py_object (X_np , y_np ), res_np )
169+ assert egraph .eval (fn_program .statements ) == snapshot_py
151170
152171 @pytest .mark .parametrize (
153172 "fn" ,
@@ -156,9 +175,29 @@ def test_source_optimized(self, snapshot_py, benchmark):
156175 pytest .param (run_lda , id = "array_api" ),
157176 pytest .param (_load_py_snapshot (test_source_optimized , "__fn" ), id = "array_api-optimized" ),
158177 pytest .param (numba .njit (_load_py_snapshot (test_source_optimized , "__fn" )), id = "array_api-optimized-numba" ),
178+ pytest .param (jit (lda ), id = "array_api-jit" ),
159179 ],
160180 )
161181 def test_execution (self , fn , benchmark ):
162182 # warmup once for numba
163183 assert np .allclose (res_np , fn (X_np , y_np ))
164184 benchmark (fn , X_np , y_np )
185+
186+
187+ # if run as a script, save the egglog source for this test to a file
188+ # similar to jit, but doesn't include the pyobject parts, so it works in vanilla egglog
189+ if __name__ == "__main__" :
190+ print ("Generating egglog source for test" )
191+ egraph = EGraph (save_egglog_string = True )
192+ X_ = NDArray .var ("X" )
193+ y_ = NDArray .var ("y" )
194+ with egraph :
195+ expr = lda (X_ , y_ )
196+ optimized_expr = egraph .simplify (expr , array_api_numba_schedule )
197+ fn_program = ndarray_function_two_program (optimized_expr , X_ , y_ )
198+ egraph .register (fn_program .compile ())
199+ egraph .run (array_api_program_gen_ruleset .saturate () + program_gen_ruleset .saturate ())
200+ egraph .extract (fn_program .statements )
201+ name = "python.egg"
202+ print ("Saving to" , name )
203+ Path (name ).write_text (egraph .as_egglog_string )
0 commit comments