Skip to content

Commit d0159e5

Browse files
author
Victor Garcia Reolid
committed
feat: support axis
Signed-off-by: Victor Garcia Reolid <[email protected]>
1 parent 5aa1a39 commit d0159e5

File tree

2 files changed

+69
-31
lines changed

2 files changed

+69
-31
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import scipy
1212
import scipy.special
1313
from llvmlite import ir
14-
from numba import types
14+
from numba import prange, types
1515
from numba.core.errors import NumbaWarning, TypingError
1616
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
1717
from numba.extending import box, overload
@@ -437,7 +437,14 @@ def shape_i(x):
437437
def numba_funcify_SortOp(op, node, **kwargs):
438438
@numba_njit
439439
def sort_f(a, axis):
440-
return np.sort(a) # numba supports sort without arguments
440+
if not isinstance(axis, int):
441+
axis = -1
442+
443+
a_swapped = np.swapaxes(a, axis, -1)
444+
a_sorted = np.sort(a_swapped)
445+
a_sorted_swapped = np.swapaxes(a_sorted, -1, axis)
446+
447+
return a_sorted_swapped
441448

442449
if op.kind != "quicksort":
443450
warnings.warn(
@@ -455,10 +462,27 @@ def sort_f(a, axis):
455462
def numba_funcify_ArgSortOp(op, node, **kwargs):
456463
def argsort_f_kind(kind):
457464
@numba_njit
458-
def argsort_f(a, axis):
459-
return np.argsort(a, kind=kind)
465+
def argort_vec(X, axis):
466+
if axis > len(X.shape):
467+
raise ValueError("Wrong axis.")
468+
469+
axis = axis.item()
470+
471+
Y = np.swapaxes(X, axis, 0)
472+
result = np.empty_like(Y)
473+
474+
N = int(np.prod(np.array(Y.shape)[1:]))
475+
indices = list(np.ndindex(Y.shape[1:]))
476+
477+
for i in prange(N):
478+
idx = indices[i]
479+
result[:, *idx] = np.argsort(Y[:, *idx], kind=kind)
480+
481+
result = np.swapaxes(result, 0, axis)
482+
483+
return result
460484

461-
return argsort_f
485+
return argort_vec
462486

463487
kind = op.kind
464488

tests/link/numba/test_basic.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -380,45 +380,59 @@ def test_Shape(x, i):
380380

381381

382382
@pytest.mark.parametrize(
383-
"kind, exc",
383+
"x, axis, kind, exc",
384384
[
385-
["quicksort", None],
386-
["mergesort", UserWarning],
387-
["heapsort", UserWarning],
388-
["stable", UserWarning],
385+
[[3, 2, 1], None, "quicksort", None],
386+
[[], None, "quicksort", None],
387+
[[[3, 2, 1], [5, 6, 7]], None, "quicksort", None],
388+
[[3, 2, 1], None, "mergesort", UserWarning],
389+
[[3, 2, 1], None, "heapsort", UserWarning],
390+
[[3, 2, 1], None, "stable", UserWarning],
391+
[[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None],
392+
[[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None],
393+
[[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None],
394+
[[3, 2, 1], 0, "quicksort", None],
395+
[np.random.randint(0, 100, (40, 40, 40, 40)), 3, "quicksort", None],
389396
],
390397
)
391-
def test_Sort(kind, exc):
392-
x = [5, 4, 3, 2, 1]
398+
def test_Sort(x, axis, kind, exc):
399+
if axis:
400+
g = SortOp(kind)(pt.as_tensor_variable(x), axis)
401+
else:
402+
g = SortOp(kind)(pt.as_tensor_variable(x))
393403

394-
g = SortOp(kind)(pt.as_tensor_variable(x))
404+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
395405

396-
if exc:
397-
with pytest.warns(exc):
398-
compare_numba_and_py([], [g], [])
399-
else:
406+
with cm:
400407
compare_numba_and_py([], [g], [])
401408

402-
compare_numba_and_py([], [g], [])
403-
404409

405410
@pytest.mark.parametrize(
406-
"kind, exc",
411+
"x, axis, kind, exc",
407412
[
408-
["quicksort", None],
409-
["mergesort", None],
410-
["heapsort", UserWarning],
411-
["stable", UserWarning],
413+
[[3, 2, 1], None, "quicksort", None],
414+
[[], None, "quicksort", None],
415+
[[[3, 2, 1], [5, 6, 7]], None, "quicksort", None],
416+
[[3, 2, 1], None, "heapsort", UserWarning],
417+
[[3, 2, 1], None, "stable", UserWarning],
418+
[[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None],
419+
[[[3, 2, 1], [5, 6, 7]], None, "quicksort", None],
420+
[[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None],
421+
[[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None],
422+
[[3, 2, 1], 0, "quicksort", None],
423+
[np.random.randint(0, 10, (3, 2, 3)), 1, "quicksort", None],
424+
[np.random.randint(0, 10, (3, 2, 3, 4, 4)), 2, "quicksort", None],
412425
],
413426
)
414-
def test_ArgSort(kind, exc):
415-
x = [5, 4, 3, 2, 1]
416-
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
417-
418-
if exc:
419-
with pytest.warns(exc):
420-
compare_numba_and_py([], [g], [])
427+
def test_ArgSort(x, axis, kind, exc):
428+
if axis:
429+
g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis)
421430
else:
431+
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
432+
433+
cm = contextlib.suppress() if not exc else pytest.warns(exc)
434+
435+
with cm:
422436
compare_numba_and_py([], [g], [])
423437

424438

0 commit comments

Comments
 (0)