Skip to content

Commit e85940e

Browse files
committed
permit Python scalars as second argument to dpt.searchsorted
1 parent 2e67f59 commit e85940e

File tree

1 file changed

+53
-34
lines changed

1 file changed

+53
-34
lines changed

dpctl/tensor/_searchsorted.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
from typing import Literal, Union
22

33
import dpctl
4+
import dpctl.tensor as dpt
45
import dpctl.utils as du
56

67
from ._copy_utils import _empty_like_orderK
78
from ._ctors import empty
9+
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
810
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
911
from ._tensor_impl import _take as ti_take
1012
from ._tensor_impl import (
1113
default_device_index_type as ti_default_device_index_type,
1214
)
1315
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
14-
from ._type_utils import isdtype, result_type
16+
from ._type_utils import (
17+
_resolve_weak_types_all_py_ints,
18+
_to_device_supported_dtype,
19+
isdtype,
20+
)
1521
from ._usmarray import usm_ndarray
1622

1723

1824
def searchsorted(
1925
x1: usm_ndarray,
20-
x2: usm_ndarray,
26+
x2: Union[usm_ndarray, int, float, complex, bool],
2127
/,
2228
*,
2329
side: Literal["left", "right"] = "left",
@@ -34,8 +40,8 @@ def searchsorted(
3440
input array. Must be a one-dimensional array. If `sorter` is
3541
`None`, must be sorted in ascending order; otherwise, `sorter` must
3642
be an array of indices that sort `x1` in ascending order.
37-
x2 (usm_ndarray):
38-
array containing search values.
43+
x2 (Union[usm_ndarray, bool, int, float, complex]):
44+
search value or values.
3945
side (Literal["left", "right]):
4046
argument controlling which index is returned if a value lands
4147
exactly on an edge. If `x2` is an array of rank `N` where
@@ -56,8 +62,6 @@ def searchsorted(
5662
"""
5763
if not isinstance(x1, usm_ndarray):
5864
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
59-
if not isinstance(x2, usm_ndarray):
60-
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
6165
if sorter is not None and not isinstance(sorter, usm_ndarray):
6266
raise TypeError(
6367
f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
@@ -69,23 +73,39 @@ def searchsorted(
6973
"Expected either 'left' or 'right'"
7074
)
7175

72-
if sorter is None:
73-
q = du.get_execution_queue([x1.sycl_queue, x2.sycl_queue])
74-
else:
75-
q = du.get_execution_queue(
76-
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
77-
)
76+
q1, x1_usm_type = x1.sycl_queue, x1.usm_type
77+
q2, x2_usm_type = _get_queue_usm_type(x2)
78+
q3 = sorter.sycl_queue if sorter is not None else None
79+
q = du.get_execution_queue(tuple(q for q in (q1, q2, q3) if q is not None))
7880
if q is None:
7981
raise du.ExecutionPlacementError(
8082
"Execution placement can not be unambiguously "
8183
"inferred from input arguments."
8284
)
8385

86+
res_usm_type = du.get_coerced_usm_type(
87+
tuple(
88+
ut
89+
for ut in (
90+
x1_usm_type,
91+
x2_usm_type,
92+
)
93+
if ut is not None
94+
)
95+
)
96+
du.validate_usm_type(res_usm_type, allow_none=False)
97+
sycl_dev = q.sycl_device
98+
8499
if x1.ndim != 1:
85100
raise ValueError("First argument array must be one-dimensional")
86101

87102
x1_dt = x1.dtype
88-
x2_dt = x2.dtype
103+
x2_dt = _get_dtype(x2, sycl_dev)
104+
if not _validate_dtype(x2_dt):
105+
raise ValueError(
106+
"dpt.searchsorted search value argument has "
107+
f"unsupported data type {x2_dt}"
108+
)
89109

90110
_manager = du.SequentialOrderManager[q]
91111
dep_evs = _manager.submitted_events
@@ -100,7 +120,7 @@ def searchsorted(
100120
"Sorter array must be one-dimension with the same "
101121
"shape as the first argument array"
102122
)
103-
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
123+
res = empty(x1.shape, dtype=x1_dt, usm_type=x1_usm_type, sycl_queue=q)
104124
ind = (sorter,)
105125
axis = 0
106126
wrap_out_of_bound_indices_mode = 0
@@ -116,29 +136,28 @@ def searchsorted(
116136
x1 = res
117137
_manager.add_event_pair(ht_ev, ev)
118138

119-
if x1_dt != x2_dt:
120-
dt = result_type(x1, x2)
121-
if x1_dt != dt:
122-
x1_buf = _empty_like_orderK(x1, dt)
123-
dep_evs = _manager.submitted_events
124-
ht_ev, ev = ti_copy(
125-
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
126-
)
127-
_manager.add_event_pair(ht_ev, ev)
128-
x1 = x1_buf
129-
if x2_dt != dt:
130-
x2_buf = _empty_like_orderK(x2, dt)
131-
dep_evs = _manager.submitted_events
132-
ht_ev, ev = ti_copy(
133-
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
134-
)
135-
_manager.add_event_pair(ht_ev, ev)
136-
x2 = x2_buf
139+
dt1, dt2 = _resolve_weak_types_all_py_ints(x1_dt, x2_dt, sycl_dev)
140+
dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)
141+
142+
if x1_dt != dt:
143+
x1_buf = _empty_like_orderK(x1, dt)
144+
dep_evs = _manager.submitted_events
145+
ht_ev, ev = ti_copy(src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs)
146+
_manager.add_event_pair(ht_ev, ev)
147+
x1 = x1_buf
148+
149+
if not isinstance(x2, usm_ndarray):
150+
x2 = dpt.asarray(x2, dtype=dt2, usm_type=res_usm_type, sycl_queue=q)
151+
if x2.dtype != dt:
152+
x2_buf = _empty_like_orderK(x2, dt)
153+
dep_evs = _manager.submitted_events
154+
ht_ev, ev = ti_copy(src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs)
155+
_manager.add_event_pair(ht_ev, ev)
156+
x2 = x2_buf
137157

138-
dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
139158
index_dt = ti_default_device_index_type(q)
140159

141-
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
160+
dst = _empty_like_orderK(x2, index_dt, usm_type=res_usm_type)
142161

143162
dep_evs = _manager.submitted_events
144163
if side == "left":

0 commit comments

Comments
 (0)