Skip to content

Commit b7c56e1

Browse files
committed
add test for python scalar second argument for searchsorted
1 parent fff27a8 commit b7c56e1

File tree

1 file changed

+80
-42
lines changed

1 file changed

+80
-42
lines changed

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 80 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
119
import numpy as np
220
import pytest
321

@@ -7,6 +25,30 @@
725

826
from .helper import get_queue_or_skip, skip_if_dtype_not_supported
927

28+
_integer_dtypes = [
29+
"i1",
30+
"u1",
31+
"i2",
32+
"u2",
33+
"i4",
34+
"u4",
35+
"i8",
36+
"u8",
37+
]
38+
39+
_floating_dtypes = [
40+
"f2",
41+
"f4",
42+
"f8",
43+
]
44+
45+
_complex_dtypes = [
46+
"c8",
47+
"c16",
48+
]
49+
50+
_all_dtypes = ["?"] + _integer_dtypes + _floating_dtypes + _complex_dtypes
51+
1052

1153
def _check(hay_stack, needles, needles_np):
1254
assert hay_stack.dtype == needles.dtype
@@ -73,19 +115,7 @@ def test_searchsorted_strided_bool():
73115
)
74116

75117

76-
@pytest.mark.parametrize(
77-
"idt",
78-
[
79-
dpt.int8,
80-
dpt.uint8,
81-
dpt.int16,
82-
dpt.uint16,
83-
dpt.int32,
84-
dpt.uint32,
85-
dpt.int64,
86-
dpt.uint64,
87-
],
88-
)
118+
@pytest.mark.parametrize("idt", _integer_dtypes)
89119
def test_searchsorted_contig_int(idt):
90120
q = get_queue_or_skip()
91121
skip_if_dtype_not_supported(idt, q)
@@ -105,19 +135,7 @@ def test_searchsorted_contig_int(idt):
105135
)
106136

107137

108-
@pytest.mark.parametrize(
109-
"idt",
110-
[
111-
dpt.int8,
112-
dpt.uint8,
113-
dpt.int16,
114-
dpt.uint16,
115-
dpt.int32,
116-
dpt.uint32,
117-
dpt.int64,
118-
dpt.uint64,
119-
],
120-
)
138+
@pytest.mark.parametrize("idt", _integer_dtypes)
121139
def test_searchsorted_strided_int(idt):
122140
q = get_queue_or_skip()
123141
skip_if_dtype_not_supported(idt, q)
@@ -144,12 +162,12 @@ def _add_extended_fp(array):
144162
array[-1] = dpt.nan
145163

146164

147-
@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
148-
def test_searchsorted_contig_fp(idt):
165+
@pytest.mark.parametrize("fdt", _floating_dtypes)
166+
def test_searchsorted_contig_fp(fdt):
149167
q = get_queue_or_skip()
150-
skip_if_dtype_not_supported(idt, q)
168+
skip_if_dtype_not_supported(fdt, q)
151169

152-
dt = dpt.dtype(idt)
170+
dt = dpt.dtype(fdt)
153171

154172
hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
155173
_add_extended_fp(hay_stack)
@@ -165,12 +183,12 @@ def test_searchsorted_contig_fp(idt):
165183
)
166184

167185

168-
@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
169-
def test_searchsorted_strided_fp(idt):
186+
@pytest.mark.parametrize("fdt", _floating_dtypes)
187+
def test_searchsorted_strided_fp(fdt):
170188
q = get_queue_or_skip()
171-
skip_if_dtype_not_supported(idt, q)
189+
skip_if_dtype_not_supported(fdt, q)
172190

173-
dt = dpt.dtype(idt)
191+
dt = dpt.dtype(fdt)
174192

175193
hay_stack = dpt.repeat(
176194
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
@@ -213,12 +231,12 @@ def _add_extended_cfp(array):
213231
return dpt.sort(dpt.concat((ev, array)))
214232

215233

216-
@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
217-
def test_searchsorted_contig_cfp(idt):
234+
@pytest.mark.parametrize("cdt", _complex_dtypes)
235+
def test_searchsorted_contig_cfp(cdt):
218236
q = get_queue_or_skip()
219-
skip_if_dtype_not_supported(idt, q)
237+
skip_if_dtype_not_supported(cdt, q)
220238

221-
dt = dpt.dtype(idt)
239+
dt = dpt.dtype(cdt)
222240

223241
hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
224242
hay_stack = _add_extended_cfp(hay_stack)
@@ -233,12 +251,12 @@ def test_searchsorted_contig_cfp(idt):
233251
)
234252

235253

236-
@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
237-
def test_searchsorted_strided_cfp(idt):
254+
@pytest.mark.parametrize("cdt", _complex_dtypes)
255+
def test_searchsorted_strided_cfp(cdt):
238256
q = get_queue_or_skip()
239-
skip_if_dtype_not_supported(idt, q)
257+
skip_if_dtype_not_supported(cdt, q)
240258

241-
dt = dpt.dtype(idt)
259+
dt = dpt.dtype(cdt)
242260

243261
hay_stack = dpt.repeat(
244262
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
@@ -375,3 +393,23 @@ def test_searchsorted_strided_scalar_needle():
375393
needles = dpt.asarray(needles_np)
376394

377395
_check(hay_stack, needles, needles_np)
396+
397+
398+
@pytest.mark.parametrize("dt", _all_dtypes)
399+
def test_searchsorted_py_scalars(dt):
400+
q = get_queue_or_skip()
401+
skip_if_dtype_not_supported(dt, q)
402+
403+
x = dpt.zeros(10, dtype=dt, sycl_queue=q)
404+
py_zeros = (
405+
bool(0),
406+
int(0),
407+
float(0),
408+
complex(0),
409+
np.float32(0),
410+
ctypes.c_int(0),
411+
)
412+
for sc in py_zeros:
413+
r1 = dpt.searchsorted(x, sc)
414+
assert isinstance(r1, dpt.usm_ndarray)
415+
assert r1.shape == ()

0 commit comments

Comments
 (0)