@@ -380,20 +380,21 @@ def test_Shape(x, i):
380380
381381
382382@pytest .mark .parametrize (
383- "x, axis, kind, exc " ,
383+ "x" ,
384384 [
385- [[3 , 2 , 1 ], None , "quicksort" , None ],
386- [[], None , "quicksort" , None ],
387- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], None , "quicksort" , None ],
388- [[3 , 2 , 1 ], None , "mergesort" , UserWarning ],
389- [[3 , 2 , 1 ], None , "heapsort" , UserWarning ],
390- [[3 , 2 , 1 ], None , "stable" , UserWarning ],
391- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 0 , "quicksort" , None ],
392- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 1 , "quicksort" , None ],
393- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], - 1 , "quicksort" , None ],
394- [[3 , 2 , 1 ], 0 , "quicksort" , None ],
395- [np .random .randint (0 , 100 , (40 , 40 , 40 , 40 )), 3 , "quicksort" , None ],
396- [[3 , 2 , 1 ], - 5 , "quicksort" , np .exceptions .AxisError ],
385+ [], # Empty list
386+ [3 , 2 , 1 ], # Simple list
387+ np .random .randint (0 , 10 , (3 , 2 , 3 , 4 , 4 )), # Multi-dimensional array
388+ ],
389+ )
390+ @pytest .mark .parametrize ("axis" , [0 , - 1 , None ])
391+ @pytest .mark .parametrize (
392+ ("kind" , "exc" ),
393+ [
394+ ["quicksort" , None ],
395+ ["mergesort" , UserWarning ],
396+ ["heapsort" , UserWarning ],
397+ ["stable" , UserWarning ],
397398 ],
398399)
399400def test_Sort (x , axis , kind , exc ):
@@ -402,36 +403,35 @@ def test_Sort(x, axis, kind, exc):
402403 else :
403404 g = SortOp (kind )(pt .as_tensor_variable (x ))
404405
405- cm = (
406- contextlib .suppress ()
407- if not exc
408- else pytest .warns (exc )
409- if isinstance (exc , Warning )
410- else pytest .raises (exc )
411- )
406+ cm = contextlib .suppress () if not exc else pytest .warns (exc )
412407
413408 with cm :
414409 compare_numba_and_py ([], [g ], [])
415410
416411
417412@pytest .mark .parametrize (
418- "x, axis, kind, exc" ,
413+ "x" ,
414+ [
415+ [], # Empty list
416+ [3 , 2 , 1 ], # Simple list
417+ None , # Multi-dimensional array (see below)
418+ ],
419+ )
420+ @pytest .mark .parametrize ("axis" , [0 , - 1 , None ])
421+ @pytest .mark .parametrize (
422+ ("kind" , "exc" ),
419423 [
420- [[3 , 2 , 1 ], None , "quicksort" , None ],
421- [[], None , "quicksort" , None ],
422- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], None , "quicksort" , None ],
423- [[3 , 2 , 1 ], None , "heapsort" , UserWarning ],
424- [[3 , 2 , 1 ], None , "stable" , UserWarning ],
425- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 0 , "quicksort" , None ],
426- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], None , "quicksort" , None ],
427- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 1 , "quicksort" , None ],
428- [[[3 , 2 , 1 ], [5 , 6 , 7 ]], - 1 , "quicksort" , None ],
429- [[3 , 2 , 1 ], 0 , "quicksort" , None ],
430- [np .random .randint (0 , 10 , (3 , 2 , 3 )), 1 , "quicksort" , None ],
431- [np .random .randint (0 , 10 , (3 , 2 , 3 , 4 , 4 )), 2 , "quicksort" , None ],
424+ ["quicksort" , None ],
425+ ["heapsort" , None ],
426+ ["stable" , UserWarning ],
432427 ],
433428)
434429def test_ArgSort (x , axis , kind , exc ):
430+ if x is None :
431+ x = np .arange (5 * 5 * 5 * 5 )
432+ np .random .shuffle (x )
433+ x = np .reshape (x , (5 , 5 , 5 , 5 ))
434+
435435 if axis :
436436 g = ArgSortOp (kind )(pt .as_tensor_variable (x ), axis )
437437 else :
0 commit comments