Skip to content

Commit 8726955

Browse files
author
Victor Garcia Reolid
committed
simplify tests
Signed-off-by: Victor Garcia Reolid <[email protected]>
1 parent daa6730 commit 8726955

File tree

2 files changed

+36
-37
lines changed

2 files changed

+36
-37
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,9 +482,8 @@ def argort_vec(X, axis):
482482

483483
kind = op.kind
484484

485-
if kind in ["quicksort", "mergesort"]:
486-
return argsort_f_kind(kind)
487-
else:
485+
if kind not in ["quicksort", "mergesort"]:
486+
kind = "quicksort"
488487
warnings.warn(
489488
(
490489
f'Numba function argsort doesn\'t support kind="{op.kind}"'
@@ -493,7 +492,7 @@ def argort_vec(X, axis):
493492
UserWarning,
494493
)
495494

496-
return argsort_f_kind("quicksort")
495+
return argsort_f_kind(kind)
497496

498497

499498
@numba.extending.intrinsic

tests/link/numba/test_basic.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -380,20 +380,21 @@ def test_Shape(x, i):
380380

381381

382382
@pytest.mark.parametrize(
383-
"x, axis, kind, exc",
383+
"x",
384384
[
385-
[[3, 2, 1], None, "quicksort", None],
386-
[[], None, "quicksort", None],
387-
[[[3, 2, 1], [5, 6, 7]], None, "quicksort", None],
388-
[[3, 2, 1], None, "mergesort", UserWarning],
389-
[[3, 2, 1], None, "heapsort", UserWarning],
390-
[[3, 2, 1], None, "stable", UserWarning],
391-
[[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None],
392-
[[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None],
393-
[[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None],
394-
[[3, 2, 1], 0, "quicksort", None],
395-
[np.random.randint(0, 100, (40, 40, 40, 40)), 3, "quicksort", None],
396-
[[3, 2, 1], -5, "quicksort", np.exceptions.AxisError],
385+
[], # Empty list
386+
[3, 2, 1], # Simple list
387+
np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array
388+
],
389+
)
390+
@pytest.mark.parametrize("axis", [0, -1, None])
391+
@pytest.mark.parametrize(
392+
("kind", "exc"),
393+
[
394+
["quicksort", None],
395+
["mergesort", UserWarning],
396+
["heapsort", UserWarning],
397+
["stable", UserWarning],
397398
],
398399
)
399400
def test_Sort(x, axis, kind, exc):
@@ -402,36 +403,35 @@ def test_Sort(x, axis, kind, exc):
402403
else:
403404
g = SortOp(kind)(pt.as_tensor_variable(x))
404405

405-
cm = (
406-
contextlib.suppress()
407-
if not exc
408-
else pytest.warns(exc)
409-
if isinstance(exc, Warning)
410-
else pytest.raises(exc)
411-
)
406+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
412407

413408
with cm:
414409
compare_numba_and_py([], [g], [])
415410

416411

417412
@pytest.mark.parametrize(
418-
"x, axis, kind, exc",
413+
"x",
414+
[
415+
[], # Empty list
416+
[3, 2, 1], # Simple list
417+
None, # Multi-dimensional array (see below)
418+
],
419+
)
420+
@pytest.mark.parametrize("axis", [0, -1, None])
421+
@pytest.mark.parametrize(
422+
("kind", "exc"),
419423
[
420-
[[3, 2, 1], None, "quicksort", None],
421-
[[], None, "quicksort", None],
422-
[[[3, 2, 1], [5, 6, 7]], None, "quicksort", None],
423-
[[3, 2, 1], None, "heapsort", UserWarning],
424-
[[3, 2, 1], None, "stable", UserWarning],
425-
[[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None],
426-
[[[3, 2, 1], [5, 6, 7]], None, "quicksort", None],
427-
[[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None],
428-
[[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None],
429-
[[3, 2, 1], 0, "quicksort", None],
430-
[np.random.randint(0, 10, (3, 2, 3)), 1, "quicksort", None],
431-
[np.random.randint(0, 10, (3, 2, 3, 4, 4)), 2, "quicksort", None],
424+
["quicksort", None],
425+
["heapsort", None],
426+
["stable", UserWarning],
432427
],
433428
)
434429
def test_ArgSort(x, axis, kind, exc):
430+
if x is None:
431+
x = np.arange(5 * 5 * 5 * 5)
432+
np.random.shuffle(x)
433+
x = np.reshape(x, (5, 5, 5, 5))
434+
435435
if axis:
436436
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
437437
else:

0 commit comments

Comments
 (0)