11from typing import Literal , Union
22
33import dpctl
4+ import dpctl .tensor as dpt
45import dpctl .utils as du
56
67from ._copy_utils import _empty_like_orderK
78from ._ctors import empty
9+ from ._scalar_utils import _get_dtype , _get_queue_usm_type , _validate_dtype
810from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
911from ._tensor_impl import _take as ti_take
1012from ._tensor_impl import (
1113 default_device_index_type as ti_default_device_index_type ,
1214)
1315from ._tensor_sorting_impl import _searchsorted_left , _searchsorted_right
14- from ._type_utils import isdtype , result_type
16+ from ._type_utils import (
17+ _resolve_weak_types_all_py_ints ,
18+ _to_device_supported_dtype ,
19+ isdtype ,
20+ )
1521from ._usmarray import usm_ndarray
1622
1723
1824def searchsorted (
1925 x1 : usm_ndarray ,
20- x2 : usm_ndarray ,
26+ x2 : Union [ usm_ndarray , int , float , complex , bool ] ,
2127 / ,
2228 * ,
2329 side : Literal ["left" , "right" ] = "left" ,
@@ -34,8 +40,8 @@ def searchsorted(
3440 input array. Must be a one-dimensional array. If `sorter` is
3541 `None`, must be sorted in ascending order; otherwise, `sorter` must
3642 be an array of indices that sort `x1` in ascending order.
37- x2 (usm_ndarray):
38- array containing search values.
43+ x2 (Union[ usm_ndarray, bool, int, float, complex] ):
44+ search value or values.
3945 side (Literal["left", "right]):
4046 argument controlling which index is returned if a value lands
4147 exactly on an edge. If `x2` is an array of rank `N` where
@@ -56,8 +62,6 @@ def searchsorted(
5662 """
5763 if not isinstance (x1 , usm_ndarray ):
5864 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x1 )} " )
59- if not isinstance (x2 , usm_ndarray ):
60- raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x2 )} " )
6165 if sorter is not None and not isinstance (sorter , usm_ndarray ):
6266 raise TypeError (
6367 f"Expected dpctl.tensor.usm_ndarray, got { type (sorter )} "
@@ -69,23 +73,39 @@ def searchsorted(
6973 "Expected either 'left' or 'right'"
7074 )
7175
72- if sorter is None :
73- q = du .get_execution_queue ([x1 .sycl_queue , x2 .sycl_queue ])
74- else :
75- q = du .get_execution_queue (
76- [x1 .sycl_queue , x2 .sycl_queue , sorter .sycl_queue ]
77- )
76+ q1 , x1_usm_type = x1 .sycl_queue , x1 .usm_type
77+ q2 , x2_usm_type = _get_queue_usm_type (x2 )
78+ q3 = sorter .sycl_queue if sorter is not None else None
79+ q = du .get_execution_queue (tuple (q for q in (q1 , q2 , q3 ) if q is not None ))
7880 if q is None :
7981 raise du .ExecutionPlacementError (
8082 "Execution placement can not be unambiguously "
8183 "inferred from input arguments."
8284 )
8385
86+ res_usm_type = du .get_coerced_usm_type (
87+ tuple (
88+ ut
89+ for ut in (
90+ x1_usm_type ,
91+ x2_usm_type ,
92+ )
93+ if ut is not None
94+ )
95+ )
96+ du .validate_usm_type (res_usm_type , allow_none = False )
97+ sycl_dev = q .sycl_device
98+
8499 if x1 .ndim != 1 :
85100 raise ValueError ("First argument array must be one-dimensional" )
86101
87102 x1_dt = x1 .dtype
88- x2_dt = x2 .dtype
103+ x2_dt = _get_dtype (x2 , sycl_dev )
104+ if not _validate_dtype (x2_dt ):
105+ raise ValueError (
106+ "dpt.searchsorted search value argument has "
107+ f"unsupported data type { x2_dt } "
108+ )
89109
90110 _manager = du .SequentialOrderManager [q ]
91111 dep_evs = _manager .submitted_events
@@ -100,7 +120,7 @@ def searchsorted(
100120 "Sorter array must be one-dimension with the same "
101121 "shape as the first argument array"
102122 )
103- res = empty (x1 .shape , dtype = x1_dt , usm_type = x1 . usm_type , sycl_queue = q )
123+ res = empty (x1 .shape , dtype = x1_dt , usm_type = x1_usm_type , sycl_queue = q )
104124 ind = (sorter ,)
105125 axis = 0
106126 wrap_out_of_bound_indices_mode = 0
@@ -116,29 +136,28 @@ def searchsorted(
116136 x1 = res
117137 _manager .add_event_pair (ht_ev , ev )
118138
119- if x1_dt != x2_dt :
120- dt = result_type (x1 , x2 )
121- if x1_dt != dt :
122- x1_buf = _empty_like_orderK ( x1 , dt )
123- dep_evs = _manager . submitted_events
124- ht_ev , ev = ti_copy (
125- src = x1 , dst = x1_buf , sycl_queue = q , depends = dep_evs
126- )
127- _manager . add_event_pair ( ht_ev , ev )
128- x1 = x1_buf
129- if x2_dt != dt :
130- x2_buf = _empty_like_orderK (x2 , dt )
131- dep_evs = _manager . submitted_events
132- ht_ev , ev = ti_copy (
133- src = x2 , dst = x2_buf , sycl_queue = q , depends = dep_evs
134- )
135- _manager .add_event_pair (ht_ev , ev )
136- x2 = x2_buf
139+ dt1 , dt2 = _resolve_weak_types_all_py_ints ( x1_dt , x2_dt , sycl_dev )
140+ dt = _to_device_supported_dtype ( dpt . result_type (dt1 , dt2 ), sycl_dev )
141+
142+ if x1_dt != dt :
143+ x1_buf = _empty_like_orderK ( x1 , dt )
144+ dep_evs = _manager . submitted_events
145+ ht_ev , ev = ti_copy ( src = x1 , dst = x1_buf , sycl_queue = q , depends = dep_evs )
146+ _manager . add_event_pair ( ht_ev , ev )
147+ x1 = x1_buf
148+
149+ if not isinstance ( x2 , usm_ndarray ) :
150+ x2 = dpt . asarray (x2 , dtype = dt2 , usm_type = res_usm_type , sycl_queue = q )
151+ if x2 . dtype != dt :
152+ x2_buf = _empty_like_orderK ( x2 , dt )
153+ dep_evs = _manager . submitted_events
154+ ht_ev , ev = ti_copy ( src = x2 , dst = x2_buf , sycl_queue = q , depends = dep_evs )
155+ _manager .add_event_pair (ht_ev , ev )
156+ x2 = x2_buf
137157
138- dst_usm_type = du .get_coerced_usm_type ([x1 .usm_type , x2 .usm_type ])
139158 index_dt = ti_default_device_index_type (q )
140159
141- dst = _empty_like_orderK (x2 , index_dt , usm_type = dst_usm_type )
160+ dst = _empty_like_orderK (x2 , index_dt , usm_type = res_usm_type )
142161
143162 dep_evs = _manager .submitted_events
144163 if side == "left" :
0 commit comments