Skip to content

Commit 65d2651

Browse files
committed
parametrize zeros in searchsorted python scalar test
1 parent 1b9c2ec commit 65d2651

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,17 @@ def test_searchsorted_strided_scalar_needle():
395395
_check(hay_stack, needles, needles_np)
396396

397397

398+
@pytest.mark.parametrize(
399+
"py_zero",
400+
[bool(0), int(0), float(0), complex(0), np.float32(0), ctypes.c_int(0)],
401+
)
398402
@pytest.mark.parametrize("dt", _all_dtypes)
399-
def test_searchsorted_py_scalars(dt):
403+
def test_searchsorted_py_scalars(py_zero, dt):
400404
q = get_queue_or_skip()
401405
skip_if_dtype_not_supported(dt, q)
402406

403407
x = dpt.zeros(10, dtype=dt, sycl_queue=q)
404-
py_zeros = (
405-
bool(0),
406-
int(0),
407-
float(0),
408-
complex(0),
409-
np.float32(0),
410-
ctypes.c_int(0),
411-
)
412-
for sc in py_zeros:
413-
r1 = dpt.searchsorted(x, sc)
414-
assert isinstance(r1, dpt.usm_ndarray)
415-
assert r1.shape == ()
408+
409+
r1 = dpt.searchsorted(x, py_zero)
410+
assert isinstance(r1, dpt.usm_ndarray)
411+
assert r1.shape == ()

0 commit comments

Comments
 (0)