Skip to content

Commit e750d48

Browse files
authored
add dpnp.searchsorted (#708)
* add dpnp.searchsorted
1 parent 6326cb7 commit e750d48

File tree

9 files changed

+222
-156
lines changed

9 files changed

+222
-156
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,20 @@ INP_DLLEXPORT void dpnp_eigvals_c(const void* array_in, void* result1, size_t si
456456
template <typename _DataType, typename _idx_DataType>
457457
INP_DLLEXPORT void dpnp_argsort_c(void* array, void* result, size_t size);
458458

459+
/**
460+
* @ingroup BACKEND_API
461+
* @brief math library implementation of searchsorted function
462+
*
463+
* @param [out] result Output array.
464+
* @param [in] array Input array with data.
465+
* @param [in] v Input values to insert into array.
466+
* @param [in] side Param for choosing a case of searching for elements.
467+
* @param [in] arr_size Number of elements in input arrays.
468+
* @param [in] v_size Number of elements in input values arrays.
469+
*/
470+
template<typename _DataType, typename _IndexingType>
471+
INP_DLLEXPORT void dpnp_searchsorted_c(void* result, const void* array, const void* v, bool side, const size_t arr_size, const size_t v_size);
472+
459473
/**
460474
* @ingroup BACKEND_API
461475
* @brief math library implementation of sort function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ enum class DPNPFuncName : size_t
187187
DPNP_FN_RNG_WALD, /**< Used in numpy.random.wald() implementation */
188188
DPNP_FN_RNG_WEIBULL, /**< Used in numpy.random.weibull() implementation */
189189
DPNP_FN_RNG_ZIPF, /**< Used in numpy.random.zipf() implementation */
190+
DPNP_FN_SEARCHSORTED, /**< Used in numpy.searchsorted() implementation */
190191
DPNP_FN_SIGN, /**< Used in numpy.sign() implementation */
191192
DPNP_FN_SIN, /**< Used in numpy.sin() implementation */
192193
DPNP_FN_SINH, /**< Used in numpy.sinh() implementation */

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,105 @@ void dpnp_partition_c(
175175
dpnp_memory_free_c(shape);
176176
}
177177

178+
template<typename _DataType, typename _IndexingType>
179+
class dpnp_searchsorted_c_kernel;
180+
181+
template<typename _DataType, typename _IndexingType>
182+
void dpnp_searchsorted_c(void* result1, const void* array1_in, const void* v1_in, bool side, const size_t arr_size, const size_t v_size)
183+
{
184+
const _DataType* arr = reinterpret_cast<const _DataType*>(array1_in);
185+
const _DataType* v = reinterpret_cast<const _DataType*>(v1_in);
186+
_IndexingType* result = reinterpret_cast<_IndexingType*>(result1);
187+
188+
if ((arr == nullptr) || (v == nullptr) || (result == nullptr))
189+
{
190+
return;
191+
}
192+
193+
if (arr_size == 0)
194+
{
195+
return;
196+
}
197+
198+
if (v_size == 0)
199+
{
200+
return;
201+
}
202+
203+
cl::sycl::range<2> gws(v_size, arr_size);
204+
auto kernel_parallel_for_func = [=](cl::sycl::id<2> global_id) {
205+
size_t i = global_id[0];
206+
size_t j = global_id[1];
207+
208+
if (j != 0)
209+
{
210+
if (side)
211+
{
212+
if (j == arr_size-1)
213+
{
214+
if (v[i] == arr[j])
215+
{
216+
result[i] = arr_size - 1;
217+
}
218+
else
219+
{
220+
if (v[i] > arr[j])
221+
{
222+
result[i] = arr_size;
223+
}
224+
}
225+
}
226+
else
227+
{
228+
if ((arr[j-1] < v[i]) && (v[i] <= arr[j]))
229+
{
230+
result[i] = j;
231+
}
232+
}
233+
}
234+
else
235+
{
236+
if (j == arr_size - 1)
237+
{
238+
if ((arr[j-1] <= v[i]) && (v[i] < arr[j]))
239+
{
240+
result[i] = arr_size - 1;
241+
}
242+
else
243+
{
244+
if (v[i] == arr[j])
245+
{
246+
result[i] = arr_size;
247+
}
248+
else
249+
{
250+
if (v[i] > arr[j])
251+
{
252+
result[i] = arr_size;
253+
}
254+
}
255+
}
256+
}
257+
else
258+
{
259+
if ((arr[j-1] <= v[i]) && (v[i] < arr[j]))
260+
{
261+
result[i] = j;
262+
}
263+
}
264+
}
265+
}
266+
};
267+
268+
auto kernel_func = [&](cl::sycl::handler& cgh) {
269+
cgh.parallel_for<class dpnp_searchsorted_c_kernel<_DataType, _IndexingType>>(gws, kernel_parallel_for_func);
270+
};
271+
272+
auto event = DPNP_QUEUE.submit(kernel_func);
273+
274+
event.wait();
275+
}
276+
178277
template <typename _DataType>
179278
class dpnp_sort_c_kernel;
180279

@@ -207,6 +306,11 @@ void func_map_init_sorting(func_map_t& fmap)
207306
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_partition_c<float>};
208307
fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_partition_c<double>};
209308

309+
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_searchsorted_c<int, long>};
310+
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_searchsorted_c<long, long>};
311+
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_searchsorted_c<float, long>};
312+
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_searchsorted_c<double, long>};
313+
210314
fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_sort_c<int>};
211315
fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_sort_c<long>};
212316
fmap[DPNPFuncName::DPNP_FN_SORT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_sort_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
160160
DPNP_FN_RNG_WALD
161161
DPNP_FN_RNG_WEIBULL
162162
DPNP_FN_RNG_ZIPF
163+
DPNP_FN_SEARCHSORTED
163164
DPNP_FN_SIGN
164165
DPNP_FN_SIN
165166
DPNP_FN_SINH

dpnp/dpnp_algo/dpnp_algo_sorting.pyx

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ from dpnp.dpnp_utils cimport *
4040
__all__ += [
4141
"dpnp_argsort",
4242
"dpnp_partition",
43+
"dpnp_searchsorted",
4344
"dpnp_sort"
4445
]
4546

4647

4748
ctypedef void(*fptr_dpnp_partition_t)(void * , void * , void * , const size_t , const size_t * , const size_t)
49+
ctypedef void(*fptr_dpnp_searchsorted_t)(void * , const void * , const void * , bool , const size_t , const size_t )
4850

4951

5052
cpdef dparray dpnp_argsort(dparray in_array1):
@@ -68,5 +70,24 @@ cpdef dparray dpnp_partition(dparray arr, int kth, axis=-1, kind='introselect',
6870
return result
6971

7072

73+
cpdef dparray dpnp_searchsorted(dparray arr, dparray v, side='left'):
74+
if side is 'left':
75+
side_ = True
76+
else:
77+
side_ = False
78+
79+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
80+
81+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SEARCHSORTED, param1_type, param1_type)
82+
83+
cdef dparray result = dparray(v.shape, dtype=dpnp.int64)
84+
85+
cdef fptr_dpnp_searchsorted_t func = <fptr_dpnp_searchsorted_t > kernel_data.ptr
86+
87+
func(arr.get_data(), v.get_data(), result.get_data(), side_, arr.size, v.size)
88+
89+
return result
90+
91+
7192
cpdef dparray dpnp_sort(dparray in_array1):
7293
return call_fptr_1in_1out(DPNP_FN_SORT, in_array1, in_array1.shape)

dpnp/dpnp_iface_sorting.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
__all__ = [
5454
'argsort',
5555
'partition',
56+
'searchsorted',
5657
'sort'
5758
]
5859

@@ -135,6 +136,39 @@ def partition(arr, kth, axis=-1, kind='introselect', order=None):
135136
return call_origin(numpy.partition, arr, kth, axis, kind, order)
136137

137138

139+
def searchsorted(arr, v, side='left', sorter=None):
140+
"""
141+
Find indices where elements should be inserted to maintain order.
142+
For full documentation refer to :obj:`numpy.searchsorted`.
143+
144+
Limitations
145+
-----------
146+
Input arrays is supported as :obj:`dpnp.ndarray`.
147+
Input array is supported only sorted.
148+
Input side is supported only values ``left``, ``right``.
149+
Parameters ``sorter`` is supported only with default values.
150+
"""
151+
if not use_origin_backend():
152+
if not isinstance(arr, dparray):
153+
pass
154+
elif dpnp.sort(arr) != arr:
155+
pass
156+
elif not isinstance(v, dparray):
157+
pass
158+
elif arr.ndim != 1:
159+
pass
160+
elif arr.dtype != v.dtype:
161+
pass
162+
elif side not in ['left', 'right']:
163+
pass
164+
elif sorter is not None:
165+
pass
166+
else:
167+
return dpnp_searchsorted(arr, v, side=side)
168+
169+
return call_origin(numpy.searchsorted, arr, v, side=side, sorter=sorter)
170+
171+
138172
def sort(x1, **kwargs):
139173
"""
140174
Return a sorted copy of an array.

0 commit comments

Comments
 (0)