Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 53 additions & 34 deletions dpctl/tensor/_searchsorted.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from typing import Literal, Union

import dpctl
import dpctl.tensor as dpt
import dpctl.utils as du

from ._copy_utils import _empty_like_orderK
from ._ctors import empty
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
from ._tensor_impl import _take as ti_take
from ._tensor_impl import (
default_device_index_type as ti_default_device_index_type,
)
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
from ._type_utils import isdtype, result_type
from ._type_utils import (
_resolve_weak_types_all_py_ints,
_to_device_supported_dtype,
isdtype,
)
from ._usmarray import usm_ndarray


def searchsorted(
x1: usm_ndarray,
x2: usm_ndarray,
x2: Union[usm_ndarray, int, float, complex, bool],
/,
*,
side: Literal["left", "right"] = "left",
Expand All @@ -34,8 +40,8 @@ def searchsorted(
input array. Must be a one-dimensional array. If `sorter` is
`None`, must be sorted in ascending order; otherwise, `sorter` must
be an array of indices that sort `x1` in ascending order.
x2 (usm_ndarray):
array containing search values.
x2 (Union[usm_ndarray, bool, int, float, complex]):
search value or values.
side (Literal["left", "right"]):
argument controlling which index is returned if a value lands
exactly on an edge. If `x2` is an array of rank `N` where
Expand All @@ -56,8 +62,6 @@ def searchsorted(
"""
if not isinstance(x1, usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
if not isinstance(x2, usm_ndarray):
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
if sorter is not None and not isinstance(sorter, usm_ndarray):
raise TypeError(
f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
Expand All @@ -69,23 +73,39 @@ def searchsorted(
"Expected either 'left' or 'right'"
)

if sorter is None:
q = du.get_execution_queue([x1.sycl_queue, x2.sycl_queue])
else:
q = du.get_execution_queue(
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
)
q1, x1_usm_type = x1.sycl_queue, x1.usm_type
q2, x2_usm_type = _get_queue_usm_type(x2)
q3 = sorter.sycl_queue if sorter is not None else None
q = du.get_execution_queue(tuple(q for q in (q1, q2, q3) if q is not None))
if q is None:
raise du.ExecutionPlacementError(
"Execution placement can not be unambiguously "
"inferred from input arguments."
)

res_usm_type = du.get_coerced_usm_type(
tuple(
ut
for ut in (
x1_usm_type,
x2_usm_type,
)
if ut is not None
)
)
du.validate_usm_type(res_usm_type, allow_none=False)
sycl_dev = q.sycl_device

if x1.ndim != 1:
raise ValueError("First argument array must be one-dimensional")

x1_dt = x1.dtype
x2_dt = x2.dtype
x2_dt = _get_dtype(x2, sycl_dev)
if not _validate_dtype(x2_dt):
raise ValueError(
"dpt.searchsorted search value argument has "
f"unsupported data type {x2_dt}"
)

_manager = du.SequentialOrderManager[q]
dep_evs = _manager.submitted_events
Expand All @@ -100,7 +120,7 @@ def searchsorted(
"Sorter array must be one-dimension with the same "
"shape as the first argument array"
)
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
res = empty(x1.shape, dtype=x1_dt, usm_type=x1_usm_type, sycl_queue=q)
ind = (sorter,)
axis = 0
wrap_out_of_bound_indices_mode = 0
Expand All @@ -116,29 +136,28 @@ def searchsorted(
x1 = res
_manager.add_event_pair(ht_ev, ev)

if x1_dt != x2_dt:
dt = result_type(x1, x2)
if x1_dt != dt:
x1_buf = _empty_like_orderK(x1, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, ev)
x1 = x1_buf
if x2_dt != dt:
x2_buf = _empty_like_orderK(x2, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, ev)
x2 = x2_buf
dt1, dt2 = _resolve_weak_types_all_py_ints(x1_dt, x2_dt, sycl_dev)
dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)

if x1_dt != dt:
x1_buf = _empty_like_orderK(x1, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs)
_manager.add_event_pair(ht_ev, ev)
x1 = x1_buf

if not isinstance(x2, usm_ndarray):
x2 = dpt.asarray(x2, dtype=dt2, usm_type=res_usm_type, sycl_queue=q)
if x2.dtype != dt:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if x2.dtype != dt:
elif x2.dtype != dt:

x2_buf = _empty_like_orderK(x2, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copying x1 and x2 to their buffers could be done in parallel.
Currently, however, the code assumes a sequential execution order: x1 must be copied first, and then x2 is either cast to usm_ndarray or copied to the buffer (but always only after the x1 copy kernel has completed).

_manager.add_event_pair(ht_ev, ev)
x2 = x2_buf

dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
index_dt = ti_default_device_index_type(q)

dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
dst = _empty_like_orderK(x2, index_dt, usm_type=res_usm_type)

dep_evs = _manager.submitted_events
if side == "left":
Expand Down
124 changes: 81 additions & 43 deletions dpctl/tests/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes

import numpy as np
import pytest

Expand All @@ -7,6 +25,30 @@

from .helper import get_queue_or_skip, skip_if_dtype_not_supported

# Dtype specifier strings shared by the parametrized tests in this module.
# Each string is passed to ``skip_if_dtype_not_supported`` and, in several
# tests, to ``dpt.dtype`` before the test body runs.

# Signed/unsigned integer type codes of item size 1, 2, 4 and 8.
_integer_dtypes = [
"i1",
"u1",
"i2",
"u2",
"i4",
"u4",
"i8",
"u8",
]

# Real floating-point type codes.
_floating_dtypes = [
"f2",
"f4",
"f8",
]

# Complex floating-point type codes.
_complex_dtypes = [
"c8",
"c16",
]

# Union of the lists above, prefixed with the "?" type code; used to
# parametrize tests that must run for every data type.
_all_dtypes = ["?"] + _integer_dtypes + _floating_dtypes + _complex_dtypes


def _check(hay_stack, needles, needles_np):
assert hay_stack.dtype == needles.dtype
Expand Down Expand Up @@ -73,19 +115,7 @@ def test_searchsorted_strided_bool():
)


@pytest.mark.parametrize(
"idt",
[
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
],
)
@pytest.mark.parametrize("idt", _integer_dtypes)
def test_searchsorted_contig_int(idt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
Expand All @@ -105,19 +135,7 @@ def test_searchsorted_contig_int(idt):
)


@pytest.mark.parametrize(
"idt",
[
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
],
)
@pytest.mark.parametrize("idt", _integer_dtypes)
def test_searchsorted_strided_int(idt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
Expand All @@ -144,12 +162,12 @@ def _add_extended_fp(array):
array[-1] = dpt.nan


@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
def test_searchsorted_contig_fp(idt):
@pytest.mark.parametrize("fdt", _floating_dtypes)
def test_searchsorted_contig_fp(fdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(fdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(fdt)

hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
_add_extended_fp(hay_stack)
Expand All @@ -165,12 +183,12 @@ def test_searchsorted_contig_fp(idt):
)


@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
def test_searchsorted_strided_fp(idt):
@pytest.mark.parametrize("fdt", _floating_dtypes)
def test_searchsorted_strided_fp(fdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(fdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(fdt)

hay_stack = dpt.repeat(
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
Expand Down Expand Up @@ -213,12 +231,12 @@ def _add_extended_cfp(array):
return dpt.sort(dpt.concat((ev, array)))


@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
def test_searchsorted_contig_cfp(idt):
@pytest.mark.parametrize("cdt", _complex_dtypes)
def test_searchsorted_contig_cfp(cdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(cdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(cdt)

hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
hay_stack = _add_extended_cfp(hay_stack)
Expand All @@ -233,12 +251,12 @@ def test_searchsorted_contig_cfp(idt):
)


@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
def test_searchsorted_strided_cfp(idt):
@pytest.mark.parametrize("cdt", _complex_dtypes)
def test_searchsorted_strided_cfp(cdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(cdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(cdt)

hay_stack = dpt.repeat(
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
Expand Down Expand Up @@ -285,7 +303,7 @@ def test_searchsorted_validation():
x1 = dpt.arange(10, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("Default device could not be created")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.searchsorted(x1, None)
with pytest.raises(TypeError):
dpt.searchsorted(x1, x1, sorter=dict())
Expand Down Expand Up @@ -375,3 +393,23 @@ def test_searchsorted_strided_scalar_needle():
needles = dpt.asarray(needles_np)

_check(hay_stack, needles, needles_np)


@pytest.mark.parametrize("dt", _all_dtypes)
def test_searchsorted_py_scalars(dt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dt, q)

x = dpt.zeros(10, dtype=dt, sycl_queue=q)
py_zeros = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to parametrize the test over the values below, instead of looping over them with `for sc in py_zeros`?

bool(0),
int(0),
float(0),
complex(0),
np.float32(0),
ctypes.c_int(0),
)
for sc in py_zeros:
r1 = dpt.searchsorted(x, sc)
assert isinstance(r1, dpt.usm_ndarray)
assert r1.shape == ()