@@ -225,29 +225,33 @@ def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray:
225225 )
226226
227227
228- def linalg_val (linalg_fn : Callable [[NDArray , TupleIntLike ], NDArray ]) -> NDArray :
229- X = NDArray .var ("X" )
228+ def linalg_val (X : NDArray , linalg_fn : Callable [[NDArray , TupleIntLike ], NDArray ]) -> NDArray :
230229 assume_shape (X , (3 , 2 , 3 , 4 ))
231230 return linalg_fn (X , (0 , 1 ))
232231
233232
234233class TestLoopNest :
235234 @pytest .mark .parametrize ("linalg_fn" , [linalg_norm , linalg_norm_v2 ])
236235 def test_shape (self , linalg_fn ):
237- check_eq (linalg_val (linalg_fn ).shape , TupleInt .from_vec ((3 , 4 )), array_api_schedule )
236+ X = np .random .random ((3 , 2 , 3 , 4 ))
237+ expect = np .linalg .norm (X , axis = (0 , 1 ))
238+ assert expect .shape == (3 , 4 )
239+
240+ check_eq (linalg_val (constant ("X" , NDArray ), linalg_fn ).shape , TupleInt .from_vec ((3 , 4 )), array_api_schedule )
238241
239242 @pytest .mark .parametrize ("linalg_fn" , [linalg_norm , linalg_norm_v2 ])
240- def test_index (self , linalg_fn ):
243+ def test_abstract_index (self , linalg_fn ):
241244 i = constant ("i" , Int )
242245 j = constant ("j" , Int )
243- idxed = linalg_val (linalg_fn ).index ((i , j ))
244- _NDArray_1 = NDArray .var ("X" )
245- _Value_1 = _NDArray_1 .index (TupleInt .from_vec (Vec [Int ](Int (0 ), Int (0 ), i , j )))
246- _Value_2 = _NDArray_1 .index (TupleInt .from_vec (Vec [Int ](Int (0 ), Int (1 ), i , j )))
247- _Value_3 = _NDArray_1 .index (TupleInt .from_vec (Vec [Int ](Int (1 ), Int (0 ), i , j )))
248- _Value_4 = _NDArray_1 .index (TupleInt .from_vec (Vec [Int ](Int (1 ), Int (1 ), i , j )))
249- _Value_5 = _NDArray_1 .index (TupleInt .from_vec (Vec [Int ](Int (2 ), Int (0 ), i , j )))
250- _Value_6 = _NDArray_1 .index (TupleInt .from_vec (Vec [Int ](Int (2 ), Int (1 ), i , j )))
246+ X = constant ("X" , NDArray )
247+ idxed = linalg_val (X , linalg_fn ).index ((i , j ))
248+
249+ _Value_1 = X .index (TupleInt .from_vec (Vec [Int ](Int (0 ), Int (0 ), i , j )))
250+ _Value_2 = X .index (TupleInt .from_vec (Vec [Int ](Int (0 ), Int (1 ), i , j )))
251+ _Value_3 = X .index (TupleInt .from_vec (Vec [Int ](Int (1 ), Int (0 ), i , j )))
252+ _Value_4 = X .index (TupleInt .from_vec (Vec [Int ](Int (1 ), Int (1 ), i , j )))
253+ _Value_5 = X .index (TupleInt .from_vec (Vec [Int ](Int (2 ), Int (0 ), i , j )))
254+ _Value_6 = X .index (TupleInt .from_vec (Vec [Int ](Int (2 ), Int (1 ), i , j )))
251255 res = (
252256 (
253257 (
@@ -263,6 +267,36 @@ def test_index(self, linalg_fn):
263267 ).sqrt ()
264268 check_eq (idxed , res , array_api_schedule )
265269
270+ def test_index_codegen (self , snapshot_py ):
271+ X = NDArray .var ("X" )
272+ i = Int .var ("i" )
273+ j = Int .var ("j" )
274+ idxed = linalg_val (X , linalg_norm_v2 ).index ((i , j ))
275+ simplified_index = simplify (idxed , array_api_schedule )
276+ assert str (simplified_index ) == snapshot_py (name = "expr" )
277+
278+ res = EvalProgram (
279+ value_program (simplified_index ).function_three (ndarray_program (X ), int_program (i ), int_program (j )),
280+ {"np" : np },
281+ )
282+ egraph = EGraph ()
283+ egraph .register (res )
284+ egraph .run (array_api_program_gen_schedule )
285+ print (
286+ egraph .extract (
287+ value_program (simplified_index ).function_three (ndarray_program (X ), int_program (i ), int_program (j ))
288+ )
289+ )
290+ # egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[TupleInt.EMPTY, TupleInt.append, Int])
291+ assert egraph .eval (res .statements ) == snapshot_py (name = "code" )
292+
293+ fn_value = egraph .eval (res .py_object )
294+ X = np .random .random ((3 , 2 , 3 , 4 ))
295+ expect = np .linalg .norm (X , axis = (0 , 1 ))
296+
297+ for idxs in np .ndindex (* expect .shape ):
298+ assert np .allclose (fn_value (X , * idxs ), expect [idxs ], rtol = 1e-03 )
299+
266300
267301# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
268302# The next step will load the old snapshot and run their test on it.
0 commit comments