Skip to content

Commit 0656684

Browse files
author
Victor Garcia Reolid
committed
feat: support numba compiled sort and argsort functions
Signed-off-by: Victor Garcia Reolid <[email protected]>
1 parent a149f6c commit 0656684

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pytensor.tensor.math import Dot
3838
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3939
from pytensor.tensor.slinalg import Solve
40+
from pytensor.tensor.sort import ArgSortOp, SortOp
4041
from pytensor.tensor.type import TensorType
4142
from pytensor.tensor.type_other import MakeSlice, NoneConst
4243

@@ -432,6 +433,46 @@ def shape_i(x):
432433
return shape_i
433434

434435

436+
@numba_funcify.register(SortOp)
437+
def numba_funcify_SortOp(op, node, **kwargs):
438+
if op.kind == "quicksort":
439+
440+
@numba_njit
441+
def sort_f(a, axis):
442+
return np.sort(a) # numba supports sort without arguments
443+
else:
444+
ret_sig = get_numba_type(node.outputs[0].type)
445+
446+
def sort_f(a, axis):
447+
with numba.objmode(ret=ret_sig):
448+
ret = np.sort(a, axis=axis, kind=op.kind)
449+
return ret
450+
451+
return sort_f
452+
453+
454+
@numba_funcify.register(ArgSortOp)
455+
def numba_funcify_ArgSortOp(op, node, **kwargs):
456+
def argsort_f_kind(kind):
457+
@numba_njit
458+
def argsort_f(a, axis):
459+
return np.argsort(a, kind=kind)
460+
461+
return argsort_f
462+
463+
if op.kind in ["quicksort", "mergesort"]:
464+
return argsort_f_kind(op.kind)
465+
else:
466+
ret_sig = get_numba_type(node.outputs[0].type)
467+
468+
def argsort_f(a, axis):
469+
with numba.objmode(ret=ret_sig):
470+
ret = np.argsort(a, axis=axis, kind=op.kind)
471+
return ret
472+
473+
return argsort_f
474+
475+
435476
@numba.extending.intrinsic
436477
def direct_cast(typingctx, val, typ):
437478
if isinstance(typ, numba.types.TypeRef):

tests/link/numba/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pytensor.tensor import blas
3434
from pytensor.tensor.elemwise import Elemwise
3535
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
36+
from pytensor.tensor.sort import ArgSortOp, SortOp
3637

3738

3839
if TYPE_CHECKING:
@@ -378,6 +379,20 @@ def test_Shape(x, i):
378379
compare_numba_and_py([], [g], [])
379380

380381

382+
@pytest.mark.parametrize("kind", ["quicksort"])
383+
def test_Sort(kind):
384+
x = [5, 4, 3, 2, 1]
385+
g = SortOp(kind)(pt.as_tensor_variable(x))
386+
compare_numba_and_py([], [g], [])
387+
388+
389+
@pytest.mark.parametrize("kind", ["quicksort", "mergesort"])
390+
def test_ArgSort(kind):
391+
x = [5, 4, 3, 2, 1]
392+
g = ArgSortOp(kind)(pt.as_tensor_variable(x))
393+
compare_numba_and_py([], [g], [])
394+
395+
381396
@pytest.mark.parametrize(
382397
"v, shape, ndim",
383398
[

0 commit comments

Comments
 (0)