Skip to content

Commit b336df9

Browse files
committed
Modified the tests and the dispatch for efficiency
1 parent e64d6d3 commit b336df9

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -749,18 +749,11 @@ def ifelse(cond, *args):
749749

750750
@numba_funcify.register(Nonzero)
751751
def numba_funcify_Nonzero(op, node, **kwargs):
752-
a = node.inputs[0]
753-
754-
if a.ndim == 0:
755-
raise ValueError("Nonzero only supports non-scalar arrays.")
756-
757752
@numba_njit
758753
def nonzero(a):
759-
if a.ndim == 1:
760-
indices = np.where(a != 0)[0]
761-
return indices.astype(np.int64)
762-
763754
result_tuple = np.nonzero(a)
755+
if a.ndim == 1:
756+
return result_tuple[0]
764757
return list(result_tuple)
765758

766759
return nonzero

tests/link/numba/test_basic.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from pytensor.raise_op import assert_op
3232
from pytensor.scalar.basic import ScalarOp, as_scalar
3333
from pytensor.tensor import blas
34-
from pytensor.tensor.basic import Nonzero
3534
from pytensor.tensor.elemwise import Elemwise
3635
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3736

@@ -288,7 +287,6 @@ def assert_fn(x, y):
288287
)
289288
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
290289
numba_res = pytensor_numba_fn(*test_inputs_copy)
291-
292290
if isinstance(graph_outputs, tuple | list):
293291
for j, p in zip(numba_res, py_res, strict=True):
294292
assert_fn(j, p)
@@ -901,17 +899,9 @@ def test_function_overhead(mode, benchmark):
901899
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
902900
)
903901
def test_Nonzero(input_data):
904-
if input_data.ndim == 0:
905-
a = pt.scalar("a")
906-
elif input_data.ndim == 1:
907-
a = pt.vector("a")
908-
elif input_data.ndim == 2:
909-
a = pt.matrix("a")
910-
911-
nonzero_op = Nonzero()
902+
a = pt.tensor("a", shape=(None,) * input_data.ndim)
912903

913-
node = nonzero_op.make_node(a)
914-
graph_outputs = node.outputs
904+
graph_outputs = pt.nonzero(a)
915905

916906
compare_numba_and_py(
917907
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]

0 commit comments

Comments
 (0)