Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
Expand Down Expand Up @@ -744,3 +745,22 @@ def ifelse(cond, *args):
return res[0]

return ifelse


@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
a = node.inputs[0]

if a.ndim == 0:
raise ValueError("Nonzero only supports non-scalar arrays.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This input validation is already done by the Nonzero Op itself, so there's no need to repeat it here.


@numba_njit
def nonzero(a):
if a.ndim == 1:
indices = np.where(a != 0)[0]
return indices.astype(np.int64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ndim == 1 is not a special case in the C-backend. In that case, you get a 1-tuple. All backends should return the same thing.

Copy link
Contributor Author

@Abhinav-Khot Abhinav-Khot Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The modification also returns a 1-tuple, because this is already implemented in the `nonzero` function itself:

class Nonzero(Op):
    """
    Op that locates the non-zero entries of a tensor.

    Parameters
    ----------
    a: array_like
        Input array. It must have at least one dimension.

    Returns
    -------
    indices: list
        One int64 vector per dimension of `a`, holding the indices of the
        non-zero elements of `a` along that dimension.

    See Also
    --------
    nonzero_values : Return the non-zero elements of the input array
    flatnonzero : Return the indices of the non-zero elements of the
        flattened input array.

    """

    __props__ = ()

    def make_node(self, a):
        """Build the Apply node: one int64 vector output per input dimension."""
        tensor_in = as_tensor_variable(a)
        n_dims = tensor_in.ndim
        if n_dims == 0:
            raise ValueError("Nonzero only supports non-scalar arrays.")
        index_vars = [
            TensorType(dtype="int64", shape=(None,))() for _ in range(n_dims)
        ]
        return Apply(self, [tensor_in], index_vars)

    def perform(self, node, inp, out_):
        """Evaluate with NumPy and store each index vector as int64."""
        for dim, coords in enumerate(np.nonzero(inp[0])):
            out_[dim][0] = coords.astype("int64")

    def grad(self, inp, grads):
        """Report the gradient with respect to the single input as undefined."""
        undefined = grad_undefined(self, 0, inp[0])
        return [undefined]


_nonzero = Nonzero()


def nonzero(a, return_matrix=False):
    """
    Return the indices of the non-zero elements of `a`.

    The layout of the result depends on `return_matrix`:

        return_matrix=False (default, same as NumPy):
            A tuple of vectors, where the ith element of the jth vector is
            the index of the ith non-zero element of the input array in the
            jth dimension.

        return_matrix=True (same as PyTensor Op):
            A matrix of shape (ndim, number of nonzero elements), where
            element (i,j) is the index in the ith dimension of the jth
            non-zero element.

    Parameters
    ----------
    a : array_like
        Input array.
    return_matrix : bool
        If True, returns a symbolic matrix. If False, returns a tuple of
        arrays. Defaults to False.

    Returns
    -------
    tuple of vectors or matrix

    See Also
    --------
    nonzero_values : Return the non-zero elements of the input array
    flatnonzero : Return the indices of the non-zero elements of the
        flattened input array.

    """
    op_result = _nonzero(a)
    # Normalise to a tuple: a list of outputs is converted, anything else
    # is wrapped as a single-element tuple.
    index_vectors = tuple(op_result) if isinstance(op_result, list) else (op_result,)

    if not return_matrix:
        return index_vectors

    n_vectors = len(index_vectors)
    if n_vectors > 1:
        return stack(index_vectors, 0)
    if n_vectors == 1:
        # A single vector becomes a one-row matrix.
        return shape_padleft(index_vectors[0])
    return None

If we do not handle the ndim == 1 case separately, then for the array [1,2,0] the result would be (([2],),), which is a tuple of tuples of lists. Instead, it should be a tuple of lists, which is what we get with the modification. I have replaced this modification with:

if(a.ndim == 1):
       return result_tuple[0]

for efficiency.


result_tuple = np.nonzero(a)
return list(result_tuple)

return nonzero
23 changes: 23 additions & 0 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape

Expand Down Expand Up @@ -893,3 +894,25 @@ def test_function_overhead(mode, benchmark):
assert np.sum(fn(test_x)) == 1000

benchmark(fn, test_x)


@pytest.mark.parametrize(
    "input_data",
    [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
    """Compare the Numba and Python results of ``Nonzero`` on the given input.

    Parameters
    ----------
    input_data : numpy.ndarray
        Concrete array fed to the compiled graph. Its rank determines the
        rank of the symbolic input.
    """
    # Build the symbolic input straight from the rank of the test data.
    # This replaces an if/elif chain that had a dead scalar branch (the
    # Nonzero Op rejects 0-d inputs) and left `a` unbound for rank > 2.
    a = pt.tensor("a", shape=(None,) * input_data.ndim)

    node = Nonzero().make_node(a)

    compare_numba_and_py(
        graph_inputs=[a], graph_outputs=node.outputs, test_inputs=[input_data]
    )