@@ -380,45 +380,59 @@ def test_Shape(x, i):
380380
381381
382382@pytest .mark .parametrize (
383- "kind, exc" ,
383+ "x, axis, kind, exc" ,
384384 [
385- ["quicksort" , None ],
386- ["mergesort" , UserWarning ],
387- ["heapsort" , UserWarning ],
388- ["stable" , UserWarning ],
385+ [[3 , 2 , 1 ], None , "quicksort" , None ],
386+ [[], None , "quicksort" , None ],
387+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], None , "quicksort" , None ],
388+ [[3 , 2 , 1 ], None , "mergesort" , UserWarning ],
389+ [[3 , 2 , 1 ], None , "heapsort" , UserWarning ],
390+ [[3 , 2 , 1 ], None , "stable" , UserWarning ],
391+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 0 , "quicksort" , None ],
392+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 1 , "quicksort" , None ],
393+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], - 1 , "quicksort" , None ],
394+ [[3 , 2 , 1 ], 0 , "quicksort" , None ],
395+ [np .random .randint (0 , 100 , (40 , 40 , 40 , 40 )), 3 , "quicksort" , None ],
389396 ],
390397)
391- def test_Sort (kind , exc ):
392- x = [5 , 4 , 3 , 2 , 1 ]
398+ def test_Sort (x , axis , kind , exc ):
399+ if axis :
400+ g = SortOp (kind )(pt .as_tensor_variable (x ), axis )
401+ else :
402+ g = SortOp (kind )(pt .as_tensor_variable (x ))
393403
394- g = SortOp ( kind )( pt . as_tensor_variable ( x ) )
404+ cm = contextlib . suppress () if not exc else pytest . warns ( exc )
395405
396- if exc :
397- with pytest .warns (exc ):
398- compare_numba_and_py ([], [g ], [])
399- else :
406+ with cm :
400407 compare_numba_and_py ([], [g ], [])
401408
402- compare_numba_and_py ([], [g ], [])
403-
404409
405410@pytest .mark .parametrize (
406- "kind, exc" ,
411+ "x, axis, kind, exc" ,
407412 [
408- ["quicksort" , None ],
409- ["mergesort" , None ],
410- ["heapsort" , UserWarning ],
411- ["stable" , UserWarning ],
413+ [[3 , 2 , 1 ], None , "quicksort" , None ],
414+ [[], None , "quicksort" , None ],
415+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], None , "quicksort" , None ],
416+ [[3 , 2 , 1 ], None , "heapsort" , UserWarning ],
417+ [[3 , 2 , 1 ], None , "stable" , UserWarning ],
418+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 0 , "quicksort" , None ],
419+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], None , "quicksort" , None ],
420+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], 1 , "quicksort" , None ],
421+ [[[3 , 2 , 1 ], [5 , 6 , 7 ]], - 1 , "quicksort" , None ],
422+ [[3 , 2 , 1 ], 0 , "quicksort" , None ],
423+ [np .random .randint (0 , 10 , (3 , 2 , 3 )), 1 , "quicksort" , None ],
424+ [np .random .randint (0 , 10 , (3 , 2 , 3 , 4 , 4 )), 2 , "quicksort" , None ],
412425 ],
413426)
414- def test_ArgSort (kind , exc ):
415- x = [5 , 4 , 3 , 2 , 1 ]
416- g = ArgSortOp (kind )(pt .as_tensor_variable (x ))
417-
418- if exc :
419- with pytest .warns (exc ):
420- compare_numba_and_py ([], [g ], [])
427+ def test_ArgSort (x , axis , kind , exc ):
428+ if axis :
429+ g = ArgSortOp (kind )(pt .as_tensor_variable (x ), axis )
421430 else :
431+ g = ArgSortOp (kind )(pt .as_tensor_variable (x ))
432+
433+ cm = contextlib .suppress () if not exc else pytest .warns (exc )
434+
435+ with cm :
422436 compare_numba_and_py ([], [g ], [])
423437
424438
0 commit comments