Skip to content

Commit a2fe553

Browse files
committed
added numba backend and testsfor Nonzero
1 parent ce70501 commit a2fe553

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pytensor.scalar.basic import ScalarType
3434
from pytensor.scalar.math import Softplus
3535
from pytensor.sparse import SparseTensorType
36+
from pytensor.tensor.basic import Nonzero
3637
from pytensor.tensor.blas import BatchedDot
3738
from pytensor.tensor.math import Dot
3839
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -744,3 +745,22 @@ def ifelse(cond, *args):
744745
return res[0]
745746

746747
return ifelse
748+
749+
750+
@numba_funcify.register(Nonzero)
751+
def numba_funcify_Nonzero(op, node, **kwargs):
752+
a = node.inputs[0]
753+
754+
if a.ndim == 0:
755+
raise ValueError("Nonzero only supports non-scalar arrays.")
756+
757+
@numba_njit
758+
def nonzero(a):
759+
if a.ndim == 1:
760+
indices = np.where(a != 0)[0]
761+
return indices.astype(np.int64)
762+
763+
result_tuple = np.nonzero(a)
764+
return list(result_tuple)
765+
766+
return nonzero

tests/link/numba/test_basic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pytensor.raise_op import assert_op
3232
from pytensor.scalar.basic import ScalarOp, as_scalar
3333
from pytensor.tensor import blas
34+
from pytensor.tensor.basic import Nonzero
3435
from pytensor.tensor.elemwise import Elemwise
3536
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3637

@@ -893,3 +894,25 @@ def test_function_overhead(mode, benchmark):
893894
assert np.sum(fn(test_x)) == 1000
894895

895896
benchmark(fn, test_x)
897+
898+
899+
@pytest.mark.parametrize(
900+
"input_data",
901+
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
902+
)
903+
def test_Nonzero(input_data):
904+
if input_data.ndim == 0:
905+
a = pt.scalar("a")
906+
elif input_data.ndim == 1:
907+
a = pt.vector("a")
908+
elif input_data.ndim == 2:
909+
a = pt.matrix("a")
910+
911+
nonzero_op = Nonzero()
912+
913+
node = nonzero_op.make_node(a)
914+
graph_outputs = node.outputs
915+
916+
compare_numba_and_py(
917+
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
918+
)

0 commit comments

Comments
 (0)