1+ # Data Parallel Control (dpctl)
2+ #
3+ # Copyright 2020-2025 Intel Corporation
4+ #
5+ # Licensed under the Apache License, Version 2.0 (the "License");
6+ # you may not use this file except in compliance with the License.
7+ # You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing, software
12+ # distributed under the License is distributed on an "AS IS" BASIS,
13+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ # See the License for the specific language governing permissions and
15+ # limitations under the License.
16+
17+ import ctypes
18+
119import numpy as np
220import pytest
321
725
826from .helper import get_queue_or_skip , skip_if_dtype_not_supported
927
28+ _integer_dtypes = [
29+ "i1" ,
30+ "u1" ,
31+ "i2" ,
32+ "u2" ,
33+ "i4" ,
34+ "u4" ,
35+ "i8" ,
36+ "u8" ,
37+ ]
38+
39+ _floating_dtypes = [
40+ "f2" ,
41+ "f4" ,
42+ "f8" ,
43+ ]
44+
45+ _complex_dtypes = [
46+ "c8" ,
47+ "c16" ,
48+ ]
49+
50+ _all_dtypes = ["?" ] + _integer_dtypes + _floating_dtypes + _complex_dtypes
51+
1052
1153def _check (hay_stack , needles , needles_np ):
1254 assert hay_stack .dtype == needles .dtype
@@ -73,19 +115,7 @@ def test_searchsorted_strided_bool():
73115 )
74116
75117
76- @pytest .mark .parametrize (
77- "idt" ,
78- [
79- dpt .int8 ,
80- dpt .uint8 ,
81- dpt .int16 ,
82- dpt .uint16 ,
83- dpt .int32 ,
84- dpt .uint32 ,
85- dpt .int64 ,
86- dpt .uint64 ,
87- ],
88- )
118+ @pytest .mark .parametrize ("idt" , _integer_dtypes )
89119def test_searchsorted_contig_int (idt ):
90120 q = get_queue_or_skip ()
91121 skip_if_dtype_not_supported (idt , q )
@@ -105,19 +135,7 @@ def test_searchsorted_contig_int(idt):
105135 )
106136
107137
108- @pytest .mark .parametrize (
109- "idt" ,
110- [
111- dpt .int8 ,
112- dpt .uint8 ,
113- dpt .int16 ,
114- dpt .uint16 ,
115- dpt .int32 ,
116- dpt .uint32 ,
117- dpt .int64 ,
118- dpt .uint64 ,
119- ],
120- )
138+ @pytest .mark .parametrize ("idt" , _integer_dtypes )
121139def test_searchsorted_strided_int (idt ):
122140 q = get_queue_or_skip ()
123141 skip_if_dtype_not_supported (idt , q )
@@ -144,12 +162,12 @@ def _add_extended_fp(array):
144162 array [- 1 ] = dpt .nan
145163
146164
147- @pytest .mark .parametrize ("idt " , [ dpt . float16 , dpt . float32 , dpt . float64 ] )
148- def test_searchsorted_contig_fp (idt ):
165+ @pytest .mark .parametrize ("fdt " , _floating_dtypes )
166+ def test_searchsorted_contig_fp (fdt ):
149167 q = get_queue_or_skip ()
150- skip_if_dtype_not_supported (idt , q )
168+ skip_if_dtype_not_supported (fdt , q )
151169
152- dt = dpt .dtype (idt )
170+ dt = dpt .dtype (fdt )
153171
154172 hay_stack = dpt .linspace (0 , 1 , num = 255 , dtype = dt , endpoint = True )
155173 _add_extended_fp (hay_stack )
@@ -165,12 +183,12 @@ def test_searchsorted_contig_fp(idt):
165183 )
166184
167185
168- @pytest .mark .parametrize ("idt " , [ dpt . float16 , dpt . float32 , dpt . float64 ] )
169- def test_searchsorted_strided_fp (idt ):
186+ @pytest .mark .parametrize ("fdt " , _floating_dtypes )
187+ def test_searchsorted_strided_fp (fdt ):
170188 q = get_queue_or_skip ()
171- skip_if_dtype_not_supported (idt , q )
189+ skip_if_dtype_not_supported (fdt , q )
172190
173- dt = dpt .dtype (idt )
191+ dt = dpt .dtype (fdt )
174192
175193 hay_stack = dpt .repeat (
176194 dpt .linspace (0 , 1 , num = 255 , dtype = dt , endpoint = True ), 4
@@ -213,12 +231,12 @@ def _add_extended_cfp(array):
213231 return dpt .sort (dpt .concat ((ev , array )))
214232
215233
216- @pytest .mark .parametrize ("idt " , [ dpt . complex64 , dpt . complex128 ] )
217- def test_searchsorted_contig_cfp (idt ):
234+ @pytest .mark .parametrize ("cdt " , _complex_dtypes )
235+ def test_searchsorted_contig_cfp (cdt ):
218236 q = get_queue_or_skip ()
219- skip_if_dtype_not_supported (idt , q )
237+ skip_if_dtype_not_supported (cdt , q )
220238
221- dt = dpt .dtype (idt )
239+ dt = dpt .dtype (cdt )
222240
223241 hay_stack = dpt .linspace (0 , 1 , num = 255 , dtype = dt , endpoint = True )
224242 hay_stack = _add_extended_cfp (hay_stack )
@@ -233,12 +251,12 @@ def test_searchsorted_contig_cfp(idt):
233251 )
234252
235253
236- @pytest .mark .parametrize ("idt " , [ dpt . complex64 , dpt . complex128 ] )
237- def test_searchsorted_strided_cfp (idt ):
254+ @pytest .mark .parametrize ("cdt " , _complex_dtypes )
255+ def test_searchsorted_strided_cfp (cdt ):
238256 q = get_queue_or_skip ()
239- skip_if_dtype_not_supported (idt , q )
257+ skip_if_dtype_not_supported (cdt , q )
240258
241- dt = dpt .dtype (idt )
259+ dt = dpt .dtype (cdt )
242260
243261 hay_stack = dpt .repeat (
244262 dpt .linspace (0 , 1 , num = 255 , dtype = dt , endpoint = True ), 4
@@ -375,3 +393,23 @@ def test_searchsorted_strided_scalar_needle():
375393 needles = dpt .asarray (needles_np )
376394
377395 _check (hay_stack , needles , needles_np )
396+
397+
398+ @pytest .mark .parametrize ("dt" , _all_dtypes )
399+ def test_searchsorted_py_scalars (dt ):
400+ q = get_queue_or_skip ()
401+ skip_if_dtype_not_supported (dt , q )
402+
403+ x = dpt .zeros (10 , dtype = dt , sycl_queue = q )
404+ py_zeros = (
405+ bool (0 ),
406+ int (0 ),
407+ float (0 ),
408+ complex (0 ),
409+ np .float32 (0 ),
410+ ctypes .c_int (0 ),
411+ )
412+ for sc in py_zeros :
413+ r1 = dpt .searchsorted (x , sc )
414+ assert isinstance (r1 , dpt .usm_ndarray )
415+ assert r1 .shape == ()
0 commit comments