Skip to content

Commit f9d06d8

Browse files
add backend to count_nonzero func (#893)
* add backend for count_nonzero func * using ptr adapter Co-authored-by: Alexander-Makaryev <[email protected]>
1 parent a7ee281 commit f9d06d8

File tree

6 files changed

+49
-12
lines changed

6 files changed

+49
-12
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,18 @@ INP_DLLEXPORT void dpnp_sum_c(void* result_out,
389389
const void* initial,
390390
const long* where);
391391

392+
/**
393+
* @ingroup BACKEND_API
394+
* @brief Custom implementation of count_nonzero function
395+
*
396+
* @param [in] array1_in Input array.
397+
* @param [out] result1_out Output array.
398+
* @param [in] size Number of elements in input arrays.
399+
*
400+
*/
401+
template <typename _DataType_input, typename _DataType_output>
402+
INP_DLLEXPORT void dpnp_count_nonzero_c(void* array1_in, void* result1_out, size_t size);
403+
392404
/**
393405
* @ingroup BACKEND_API
394406
* @brief Place of array elements

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ enum class DPNPFuncName : size_t
9090
DPNP_FN_CORRELATE, /**< Used in numpy.correlate() implementation */
9191
DPNP_FN_COS, /**< Used in numpy.cos() implementation */
9292
DPNP_FN_COSH, /**< Used in numpy.cosh() implementation */
93+
DPNP_FN_COUNT_NONZERO, /**< Used in numpy.count_nonzero() implementation */
9394
DPNP_FN_COV, /**< Used in numpy.cov() implementation */
9495
DPNP_FN_CROSS, /**< Used in numpy.cross() implementation */
9596
DPNP_FN_CUMPROD, /**< Used in numpy.cumprod() implementation */

dpnp/backend/kernels/dpnp_krnl_statistics.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,33 @@ void dpnp_cov_c(void* array1_in, void* result1, size_t nrows, size_t ncols)
150150
return;
151151
}
152152

153+
template <typename _DataType_input, typename _DataType_output>
154+
void dpnp_count_nonzero_c(void* array1_in, void* result1_out, size_t size)
155+
{
156+
if (array1_in == nullptr)
157+
{
158+
return;
159+
}
160+
161+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size, true);
162+
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1_out, 1, true, true);
163+
_DataType_input* array1 = input1_ptr.get_ptr();
164+
_DataType_output* result1 = result_ptr.get_ptr();
165+
166+
167+
result1[0] = 0;
168+
169+
for (size_t i = 0; i < size; ++i)
170+
{
171+
if (array1[i] != 0)
172+
{
173+
result1[0] += 1;
174+
}
175+
}
176+
177+
return;
178+
}
179+
153180
template <typename _DataType>
154181
class dpnp_max_c_kernel;
155182

@@ -731,6 +758,11 @@ void func_map_init_statistics(func_map_t& fmap)
731758
fmap[DPNPFuncName::DPNP_FN_CORRELATE][eft_DBL][eft_DBL] = {eft_DBL,
732759
(void*)dpnp_correlate_c<double, double, double>};
733760

761+
fmap[DPNPFuncName::DPNP_FN_COUNT_NONZERO][eft_INT][eft_INT] = {eft_LNG, (void*)dpnp_count_nonzero_c<int, long>};
762+
fmap[DPNPFuncName::DPNP_FN_COUNT_NONZERO][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_count_nonzero_c<long, long>};
763+
fmap[DPNPFuncName::DPNP_FN_COUNT_NONZERO][eft_FLT][eft_FLT] = {eft_LNG, (void*)dpnp_count_nonzero_c<float, long>};
764+
fmap[DPNPFuncName::DPNP_FN_COUNT_NONZERO][eft_DBL][eft_DBL] = {eft_LNG, (void*)dpnp_count_nonzero_c<double, long>};
765+
734766
fmap[DPNPFuncName::DPNP_FN_COV][eft_INT][eft_INT] = {eft_DBL, (void*)dpnp_cov_c<double>};
735767
fmap[DPNPFuncName::DPNP_FN_COV][eft_LNG][eft_LNG] = {eft_DBL, (void*)dpnp_cov_c<double>};
736768
fmap[DPNPFuncName::DPNP_FN_COV][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_cov_c<double>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
6969
DPNP_FN_COS
7070
DPNP_FN_COSH
7171
DPNP_FN_COV
72+
DPNP_FN_COUNT_NONZERO
7273
DPNP_FN_CROSS
7374
DPNP_FN_CUMPROD
7475
DPNP_FN_CUMSUM

dpnp/dpnp_algo/dpnp_algo_counting.pyx

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,5 @@ __all__ += [
3939
]
4040

4141

42-
cpdef utils.dpnp_descriptor dpnp_count_nonzero(utils.dpnp_descriptor in_array1):
43-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py((1,), dpnp.int64, None)
44-
45-
count = 0
46-
for i in range(in_array1.size):
47-
if in_array1.get_pyobj()[i] != 0:
48-
count += 1
49-
50-
result.get_pyobj()[0] = count
51-
52-
return result
42+
cpdef utils.dpnp_descriptor dpnp_count_nonzero(utils.dpnp_descriptor x1):
43+
return call_fptr_1in_1out(DPNP_FN_COUNT_NONZERO, x1, (1,))

dpnp/dpnp_iface_counting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def count_nonzero(x1, axis=None, *, keepdims=False):
5858
5959
Limitations
6060
-----------
61-
Parameter ``in_array1`` is supported as :obj:`dpnp.ndarray`.
61+
Parameter ``x1`` is supported as :obj:`dpnp.ndarray`.
6262
Otherwise the function will be executed sequentially on CPU.
6363
Parameter ``axis`` is supported only with default value `None`.
6464
Parameter ``keepdims`` is supported only with default value `False`.

0 commit comments

Comments
 (0)