Skip to content

Commit 5aa1a39

Browse files
author
Victor Garcia Reolid
committed
default to supported kind and add warning
Signed-off-by: Victor Garcia Reolid <[email protected]>
1 parent 0656684 commit 5aa1a39

File tree

2 files changed

+57
-25
lines changed

2 files changed

+57
-25
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -435,18 +435,18 @@ def shape_i(x):
435435

436436
@numba_funcify.register(SortOp)
437437
def numba_funcify_SortOp(op, node, **kwargs):
438-
if op.kind == "quicksort":
439-
440-
@numba_njit
441-
def sort_f(a, axis):
442-
return np.sort(a) # numba supports sort without arguments
443-
else:
444-
ret_sig = get_numba_type(node.outputs[0].type)
438+
@numba_njit
439+
def sort_f(a, axis):
440+
return np.sort(a) # numba supports sort without arguments
445441

446-
def sort_f(a, axis):
447-
with numba.objmode(ret=ret_sig):
448-
ret = np.sort(a, axis=axis, kind=op.kind)
449-
return ret
442+
if op.kind != "quicksort":
443+
warnings.warn(
444+
(
445+
f'Numba function sort doesn\'t support kind="{op.kind}"'
446+
" switching to `quicksort`."
447+
),
448+
UserWarning,
449+
)
450450

451451
return sort_f
452452

@@ -460,17 +460,20 @@ def argsort_f(a, axis):
460460

461461
return argsort_f
462462

463-
if op.kind in ["quicksort", "mergesort"]:
464-
return argsort_f_kind(op.kind)
465-
else:
466-
ret_sig = get_numba_type(node.outputs[0].type)
463+
kind = op.kind
467464

468-
def argsort_f(a, axis):
469-
with numba.objmode(ret=ret_sig):
470-
ret = np.argsort(a, axis=axis, kind=op.kind)
471-
return ret
465+
if kind in ["quicksort", "mergesort"]:
466+
return argsort_f_kind(kind)
467+
else:
468+
warnings.warn(
469+
(
470+
f'Numba function argsort doesn\'t support kind="{op.kind}"'
471+
" switching to `quicksort`."
472+
),
473+
UserWarning,
474+
)
472475

473-
return argsort_f
476+
return argsort_f_kind("quicksort")
474477

475478

476479
@numba.extending.intrinsic

tests/link/numba/test_basic.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,18 +379,47 @@ def test_Shape(x, i):
379379
compare_numba_and_py([], [g], [])
380380

381381

382-
@pytest.mark.parametrize("kind", ["quicksort"])
383-
def test_Sort(kind):
382+
@pytest.mark.parametrize(
383+
"kind, exc",
384+
[
385+
["quicksort", None],
386+
["mergesort", UserWarning],
387+
["heapsort", UserWarning],
388+
["stable", UserWarning],
389+
],
390+
)
391+
def test_Sort(kind, exc):
384392
x = [5, 4, 3, 2, 1]
393+
385394
g = SortOp(kind)(pt.as_tensor_variable(x))
395+
396+
if exc:
397+
with pytest.warns(exc):
398+
compare_numba_and_py([], [g], [])
399+
else:
400+
compare_numba_and_py([], [g], [])
401+
386402
compare_numba_and_py([], [g], [])
387403

388404

389-
@pytest.mark.parametrize("kind", ["quicksort", "mergesort"])
390-
def test_ArgSort(kind):
405+
@pytest.mark.parametrize(
406+
"kind, exc",
407+
[
408+
["quicksort", None],
409+
["mergesort", None],
410+
["heapsort", UserWarning],
411+
["stable", UserWarning],
412+
],
413+
)
414+
def test_ArgSort(kind, exc):
391415
x = [5, 4, 3, 2, 1]
392416
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
393-
compare_numba_and_py([], [g], [])
417+
418+
if exc:
419+
with pytest.warns(exc):
420+
compare_numba_and_py([], [g], [])
421+
else:
422+
compare_numba_and_py([], [g], [])
394423

395424

396425
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)