Skip to content

Commit 10ccd07

Browse files
Move choose from cython to backend (#928)
* Move choose from cython to backend
1 parent e7be6a6 commit 10ccd07

File tree

7 files changed

+117
-27
lines changed

7 files changed

+117
-27
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,21 @@ INP_DLLEXPORT void dpnp_cov_c(void* array1_in, void* result1, size_t nrows, size
600600
template <typename _DataType>
601601
INP_DLLEXPORT void dpnp_det_c(void* array1_in, void* result1, size_t* shape, size_t ndim);
602602

603+
/**
604+
* @ingroup BACKEND_API
605+
* @brief math library implementation of take function
606+
*
607+
* @param [out] result Output array.
608+
* @param [in] array Input array with data.
609+
* @param [in] choices Choice arrays.
610+
* @param [in] size Input array size.
611+
* @param [in] choices_size Choices size.
612+
* @param [in] choice_size Choices size.
613+
*/
614+
template <typename _DataType1, typename _DataType2>
615+
INP_DLLEXPORT void
616+
dpnp_choose_c(void* result1, void* array1_in, void** choices, size_t size, size_t choices_size, size_t choice_size);
617+
603618
/**
604619
* @ingroup BACKEND_API
605620
* @brief math library implementation of det function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ enum class DPNPFuncName : size_t
8383
DPNP_FN_CEIL, /**< Used in numpy.ceil() implementation */
8484
DPNP_FN_CHOLESKY, /**< Used in numpy.linalg.cholesky() implementation */
8585
DPNP_FN_CONJIGUATE, /**< Used in numpy.conjugate() implementation */
86+
DPNP_FN_CHOOSE, /**< Used in numpy.choose() implementation */
8687
DPNP_FN_COPY, /**< Used in numpy.copy() implementation */
8788
DPNP_FN_COPYSIGN, /**< Used in numpy.copysign() implementation */
8889
DPNP_FN_COPYTO, /**< Used in numpy.copyto() implementation */

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,50 @@
3232
#include "dpnpc_memory_adapter.hpp"
3333
#include "queue_sycl.hpp"
3434

35+
template <typename _DataType1, typename _DataType2>
36+
class dpnp_choose_c_kernel;
37+
38+
template <typename _DataType1, typename _DataType2>
39+
void dpnp_choose_c(
40+
void* result1, void* array1_in, void** choices1, size_t size, size_t choices_size, size_t choice_size)
41+
{
42+
if ((array1_in == nullptr) || (result1 == nullptr) || (choices1 == nullptr))
43+
{
44+
return;
45+
}
46+
if (!size || !choices_size || !choice_size)
47+
{
48+
return;
49+
}
50+
DPNPC_ptr_adapter<_DataType1> input1_ptr(array1_in, size);
51+
_DataType1* array_in = input1_ptr.get_ptr();
52+
53+
DPNPC_ptr_adapter<_DataType2*> choices_ptr(choices1, choices_size);
54+
_DataType2** choices = choices_ptr.get_ptr();
55+
56+
for (size_t i = 0; i < choices_size; ++i)
57+
{
58+
DPNPC_ptr_adapter<_DataType2> choice_ptr(choices[i], choice_size);
59+
choices[i] = choice_ptr.get_ptr();
60+
}
61+
62+
DPNPC_ptr_adapter<_DataType2> result1_ptr(result1, size, false, true);
63+
_DataType2* result = result1_ptr.get_ptr();
64+
65+
cl::sycl::range<1> gws(size);
66+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
67+
const size_t idx = global_id[0];
68+
result[idx] = choices[array_in[idx]][idx];
69+
};
70+
71+
auto kernel_func = [&](cl::sycl::handler& cgh) {
72+
cgh.parallel_for<class dpnp_choose_c_kernel<_DataType1, _DataType2>>(gws, kernel_parallel_for_func);
73+
};
74+
75+
cl::sycl::event event = DPNP_QUEUE.submit(kernel_func);
76+
event.wait();
77+
}
78+
3579
template <typename _DataType>
3680
class dpnp_diag_indices_c_kernel;
3781

@@ -521,6 +565,15 @@ void dpnp_take_c(void* array1_in, const size_t array1_size, void* indices1, void
521565

522566
void func_map_init_indexing_func(func_map_t& fmap)
523567
{
568+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_choose_c<int, int>};
569+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_choose_c<int, long>};
570+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_choose_c<int, float>};
571+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_choose_c<int, double>};
572+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_choose_c<long, int>};
573+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_choose_c<long, long>};
574+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_choose_c<long, float>};
575+
fmap[DPNPFuncName::DPNP_FN_CHOOSE][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_choose_c<long, double>};
576+
524577
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_diag_indices_c<int>};
525578
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_diag_indices_c<long>};
526579
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_diag_indices_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
6060
DPNP_FN_CBRT
6161
DPNP_FN_CEIL
6262
DPNP_FN_CHOLESKY
63+
DPNP_FN_CHOOSE
6364
DPNP_FN_CONJIGUATE
6465
DPNP_FN_COPY
6566
DPNP_FN_COPYSIGN

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ and the rest of the library
3333
"""
3434

3535
from libc.time cimport time, time_t
36+
from libcpp.vector cimport vector
3637
import dpnp
3738
import dpnp.config as config
3839
import dpnp.dpnp_utils as utils_py

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ __all__ += [
5454
"dpnp_triu_indices_from"
5555
]
5656

57+
ctypedef void(*fptr_dpnp_choose_t)(void * , void * , void **, size_t, size_t, size_t)
5758
ctypedef void(*fptr_dpnp_diag_indices)(void*, size_t)
5859
ctypedef void(*custom_indexing_2in_1out_func_ptr_t)(void *, const size_t, void * , void * , size_t)
5960
ctypedef void(*custom_indexing_2in_1out_func_ptr_t_)(void * , const size_t, void * , const size_t, size_t * , size_t * , const size_t)
@@ -65,12 +66,33 @@ ctypedef void(*custom_indexing_6in_func_ptr_t)(void *, void * , void * , const s
6566
ctypedef void(*fptr_dpnp_nonzero_t)(const void * , void * , const size_t, const size_t * , const size_t , const size_t)
6667

6768

68-
cpdef utils.dpnp_descriptor dpnp_choose(object input, list choices):
69-
cdef shape_type_c obj_shape = utils._object_to_tuple(len(input))
70-
cdef utils.dpnp_descriptor res_array = utils_py.create_output_descriptor_py(obj_shape, choices[0].dtype, None)
69+
cpdef utils.dpnp_descriptor dpnp_choose(utils.dpnp_descriptor input, list choices1):
70+
cdef vector[void * ] choices
71+
cdef utils.dpnp_descriptor choice
72+
for desc in choices1:
73+
choice = desc
74+
choices.push_back(choice.get_data())
75+
76+
cdef shape_type_c input_shape = input.shape
77+
cdef size_t choice_size = choices1[0].size
78+
79+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
80+
81+
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(choices1[0].dtype)
82+
83+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_CHOOSE, param1_type, param2_type)
84+
85+
cdef utils.dpnp_descriptor res_array = utils.create_output_descriptor(input_shape, kernel_data.return_type, None)
86+
87+
cdef fptr_dpnp_choose_t func = <fptr_dpnp_choose_t> kernel_data.ptr
88+
89+
func(res_array.get_data(),
90+
input.get_data(),
91+
choices.data(),
92+
input_shape[0],
93+
choices.size(),
94+
choice_size)
7195

72-
for i in range(len(input)):
73-
res_array.get_pyobj()[i] = (choices[input[i]])[i]
7496
return res_array
7597

7698

@@ -299,7 +321,7 @@ cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default)
299321

300322
cpdef utils.dpnp_descriptor dpnp_take(utils.dpnp_descriptor input, utils.dpnp_descriptor indices):
301323
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
302-
324+
303325
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE, param1_type, param1_type)
304326

305327
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(indices.shape, kernel_data.return_type, None)

dpnp/dpnp_iface_indexing.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,35 +81,32 @@ def choose(x1, choices, out=None, mode='raise'):
8181
--------
8282
:obj:`take_along_axis` : Preferable if choices is an array.
8383
"""
84-
if not use_origin_backend(x1):
85-
if not isinstance(x1, list):
86-
pass
87-
elif not isinstance(choices, list):
84+
x1_desc = dpnp.get_dpnp_descriptor(x1)
85+
86+
choices_list =[]
87+
for choice in choices:
88+
choices_list.append(dpnp.get_dpnp_descriptor(choice))
89+
90+
if x1_desc:
91+
if any(not desc for desc in choices_list):
8892
pass
8993
elif out is not None:
9094
pass
9195
elif mode != 'raise':
9296
pass
97+
elif any(not choices[0].dtype == choice.dtype for choice in choices):
98+
pass
99+
elif not len(choices_list):
100+
pass
93101
else:
94-
val = True
95-
len_ = len(x1)
96-
size_ = choices[0].size
97-
for i in range(len(choices)):
98-
if choices[i].size != size_ or choices[i].size != len_:
99-
val = False
100-
break
101-
if not val:
102+
size = x1_desc.size
103+
choices_size = choices_list[0].size
104+
if any(choice.size != choices_size or choice.size != size for choice in choices):
105+
pass
106+
elif any(x >= choices_size for x in dpnp.asnumpy(x1)):
102107
pass
103108
else:
104-
val = True
105-
for i in range(len_):
106-
if x1[i] >= size_:
107-
val = False
108-
break
109-
if not val:
110-
pass
111-
else:
112-
return dpnp_choose(x1, choices).get_pyobj()
109+
return dpnp_choose(x1_desc, choices_list).get_pyobj()
113110

114111
return call_origin(numpy.choose, x1, choices, out, mode)
115112

0 commit comments

Comments
 (0)