Skip to content

Commit be4cef5

Browse files
authored
add backend at around function (#726)
* add backend for around func * add limitations
1 parent 949bdba commit be4cef5

File tree

6 files changed

+95
-13
lines changed

6 files changed

+95
-13
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,21 @@ INP_DLLEXPORT void dpnp_argmax_c(void* array, void* result, size_t size);
677677
template <typename _DataType, typename _idx_DataType>
678678
INP_DLLEXPORT void dpnp_argmin_c(void* array, void* result, size_t size);
679679

680+
/**
681+
* @ingroup BACKEND_API
682+
* @brief math library implementation of around function
683+
*
684+
* @param [in] input_in Input array with data.
685+
* @param [out] result_out Output array with indeces.
686+
* @param [in] input_size Number of elements in input arrays.
687+
* @param [in] decimals Number of decimal places to round. Support only with default value 0.
688+
*/
689+
template <typename _DataType>
690+
INP_DLLEXPORT void dpnp_around_c(const void* input_in,
691+
void* result_out,
692+
const size_t input_size,
693+
const int decimals);
694+
680695
/**
681696
* @ingroup BACKEND_API
682697
* @brief math library implementation of std function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ enum class DPNPFuncName : size_t
7373
DPNP_FN_ARGMAX, /**< Used in numpy.argmax() implementation */
7474
DPNP_FN_ARGMIN, /**< Used in numpy.argmin() implementation */
7575
DPNP_FN_ARGSORT, /**< Used in numpy.argsort() implementation */
76+
DPNP_FN_AROUND, /**< Used in numpy.around() implementation */
7677
DPNP_FN_ASTYPE, /**< Used in numpy.astype() implementation */
7778
DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() implementation */
7879
DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() implementation */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,51 @@
3232
#include "dpnp_utils.hpp"
3333
#include "queue_sycl.hpp"
3434

35+
template <typename _KernelNameSpecialization>
36+
class dpnp_around_c_kernel;
37+
38+
template <typename _DataType>
39+
void dpnp_around_c(const void* input_in,
40+
void* result_out,
41+
const size_t input_size,
42+
const int decimals)
43+
{
44+
(void)decimals;
45+
46+
if (!input_size)
47+
{
48+
return;
49+
}
50+
51+
cl::sycl::event event;
52+
_DataType* input = reinterpret_cast<_DataType*>(const_cast<void*>(input_in));
53+
_DataType* result = reinterpret_cast<_DataType*>(result_out);
54+
55+
if constexpr(std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
56+
{
57+
event = oneapi::mkl::vm::rint(DPNP_QUEUE, input_size, input, result);
58+
}
59+
else
60+
{
61+
cl::sycl::range<1> gws(input_size);
62+
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) {
63+
size_t i = global_id[0];
64+
{
65+
result[i] = std::rint(input[i]);
66+
}
67+
};
68+
69+
auto kernel_func = [&](cl::sycl::handler& cgh) {
70+
cgh.parallel_for<class dpnp_around_c_kernel<_DataType>>(
71+
gws, kernel_parallel_for_func);
72+
};
73+
74+
event = DPNP_QUEUE.submit(kernel_func);
75+
}
76+
77+
event.wait();
78+
}
79+
3580
template <typename _KernelNameSpecialization>
3681
class dpnp_elemwise_absolute_c_kernel;
3782

@@ -394,6 +439,11 @@ void func_map_init_mathematical(func_map_t& fmap)
394439
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_elemwise_absolute_c<float>};
395440
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_elemwise_absolute_c<double>};
396441

442+
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_around_c<int>};
443+
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_around_c<long>};
444+
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_around_c<float>};
445+
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_around_c<double>};
446+
397447
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_cross_c<int, int, int>};
398448
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_cross_c<long, int, long>};
399449
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_cross_c<double, int, float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
4646
DPNP_FN_ARGMAX
4747
DPNP_FN_ARGMIN
4848
DPNP_FN_ARGSORT
49+
DPNP_FN_AROUND
4950
DPNP_FN_ASTYPE
5051
DPNP_FN_BITWISE_AND
5152
DPNP_FN_BITWISE_OR

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ __all__ += [
8080
ctypedef void(*fptr_custom_elemwise_absolute_1in_1out_t)(void * , void * , size_t)
8181
ctypedef void(*fptr_1in_2out_t)(void * , void * , void * , size_t)
8282
ctypedef void(*ftpr_custom_trapz_2in_1out_with_2size_t)(void *, void * , void * , double, size_t, size_t)
83+
ctypedef void(*ftpr_custom_around_1in_1out_t)(const void *, void *, const size_t, const int)
8384

8485

8586
cpdef dparray dpnp_absolute(dparray input):
@@ -111,16 +112,19 @@ cpdef dparray dpnp_arctan2(object x1_obj, object x2_obj, object dtype=None, dpar
111112
return call_fptr_2in_1out(DPNP_FN_ARCTAN2, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
112113

113114

114-
cpdef dpnp_around(dparray a, decimals, out):
115-
cdef dparray result
115+
cpdef dpnp_around(dparray x1, int decimals):
116116

117-
if out is None:
118-
result = dparray(a.shape, dtype=a.dtype)
119-
else:
120-
result = out
117+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
121118

122-
for i in range(result.size):
123-
result[i] = round(a[i], decimals)
119+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_AROUND, param1_type, param1_type)
120+
121+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
122+
123+
cdef dparray result = dparray(x1.shape, dtype=result_type)
124+
125+
cdef ftpr_custom_around_1in_1out_t func = <ftpr_custom_around_1in_1out_t > kernel_data.ptr
126+
127+
func(x1.get_data(), result.get_data(), x1.size, decimals)
124128

125129
return result
126130

dpnp/dpnp_iface_mathematical.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,19 @@ def add(x1, x2, dtype=None, out=None, where=True, **kwargs):
213213
return call_origin(numpy.add, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
214214

215215

216-
def around(a, decimals=0, out=None):
216+
def around(x1, decimals=0, out=None):
217217
"""
218218
Evenly round to the given number of decimals.
219219
220220
For full documentation refer to :obj:`numpy.around`.
221221
222+
Limitations
223+
-----------
224+
Parameters ``x1`` is supported as :obj:`dpnp.ndarray`.
225+
Parameters ``decimals`` and ``out`` are supported with their default values.
226+
Otherwise the functions will be executed sequentially on CPU.
227+
Input array data types are limited by supported DPNP :ref:`Data types`.
228+
222229
Examples
223230
--------
224231
>>> import dpnp as np
@@ -235,13 +242,17 @@ def around(a, decimals=0, out=None):
235242
236243
"""
237244

238-
if not use_origin_backend(a):
239-
if not isinstance(a, dparray):
245+
if not use_origin_backend(x1):
246+
if not isinstance(x1, dparray):
247+
pass
248+
elif out is not None:
249+
pass
250+
elif decimals != 0:
240251
pass
241252
else:
242-
return dpnp_around(a, decimals, out)
253+
return dpnp_around(x1, decimals)
243254

244-
return call_origin(numpy.around, a, decimals, out)
255+
return call_origin(numpy.around, x1, decimals=decimals, out=out)
245256

246257

247258
def ceil(x1, **kwargs):

0 commit comments

Comments
 (0)