diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 7685c17d9c..7ff7f48967 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -33,6 +33,7 @@ from pytensor.scalar.basic import ScalarType from pytensor.scalar.math import Softplus from pytensor.sparse import SparseTensorType +from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @@ -744,3 +745,15 @@ def ifelse(cond, *args): return res[0] return ifelse + + +@numba_funcify.register(Nonzero) +def numba_funcify_Nonzero(op, node, **kwargs): + @numba_njit + def nonzero(a): + result_tuple = np.nonzero(a) + if a.ndim == 1: + return result_tuple[0] + return list(result_tuple) + + return nonzero diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 654cbe7bd4..1346da55e7 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -287,7 +287,6 @@ def assert_fn(x, y): ) test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs numba_res = pytensor_numba_fn(*test_inputs_copy) - if isinstance(graph_outputs, tuple | list): for j, p in zip(numba_res, py_res, strict=True): assert_fn(j, p) @@ -893,3 +892,17 @@ def test_function_overhead(mode, benchmark): assert np.sum(fn(test_x)) == 1000 benchmark(fn, test_x) + + +@pytest.mark.parametrize( + "input_data", + [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])], +) +def test_Nonzero(input_data): + a = pt.tensor("a", shape=(None,) * input_data.ndim) + + graph_outputs = pt.nonzero(a) + + compare_numba_and_py( + graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data] + )