Skip to content

Commit 3a43a96

Browse files
Move diag_indices from cython to backend
* add c-function for dpnp_diag_indoces * Change dpnp_diag_indices * Fix parametrs of dpnp_diag_indices_c * Fix dpnp_algo.pyx * Fix dpnp_algo_indexing.pyx * Update dpnp/dpnp_algo/dpnp_algo_indexing.pyx Co-authored-by: densmirn <[email protected]> * Change dpnp_diag_indices * Update dpnp_krnl_indexing.cpp * Update dpnp/dpnp_algo/dpnp_algo_indexing.pyx Co-authored-by: densmirn <[email protected]> Co-authored-by: densmirn <[email protected]> Co-authored-by: densmirn <[email protected]>
1 parent 0a03265 commit 3a43a96

File tree

5 files changed

+42
-6
lines changed

5 files changed

+42
-6
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,16 @@ template <typename _DataType>
579579
INP_DLLEXPORT void dpnp_diag_c(
580580
void* array, void* result, const int k, size_t* shape, size_t* res_shape, const size_t ndim, const size_t res_ndim);
581581

582+
/**
583+
* @ingroup BACKEND_API
584+
* @brief math library implementation of diagonal function
585+
*
586+
* @param [out] result Output array.
587+
* @param [in] size Size of array.
588+
*/
589+
template <typename _DataType>
590+
INP_DLLEXPORT void dpnp_diag_indices_c(void* result1, size_t size);
591+
582592
/**
583593
* @ingroup BACKEND_API
584594
* @brief math library implementation of diagonal function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ enum class DPNPFuncName : size_t
9696
DPNP_FN_DEGREES, /**< Used in numpy.degrees() implementation */
9797
DPNP_FN_DET, /**< Used in numpy.linalg.det() implementation */
9898
DPNP_FN_DIAG, /**< Used in numpy.diag() implementation */
99+
DPNP_FN_DIAG_INDICES, /**< Used in numpy.diag_indices() implementation */
99100
DPNP_FN_DIAGONAL, /**< Used in numpy.diagonal() implementation */
100101
DPNP_FN_DIVIDE, /**< Used in numpy.divide() implementation */
101102
DPNP_FN_DOT, /**< Used in numpy.dot() implementation */

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@
3232
#include "dpnpc_memory_adapter.hpp"
3333
#include "queue_sycl.hpp"
3434

35+
template <typename _DataType>
36+
class dpnp_diag_indices_c_kernel;
37+
38+
template <typename _DataType>
39+
void dpnp_diag_indices_c(void* result1, size_t size)
40+
{
41+
dpnp_arange_c<_DataType>(0, 1, result1, size);
42+
}
43+
3544
template <typename _DataType>
3645
class dpnp_diagonal_c_kernel;
3746

@@ -503,6 +512,11 @@ void dpnp_take_c(void* array1_in, const size_t array1_size, void* indices1, void
503512

504513
void func_map_init_indexing_func(func_map_t& fmap)
505514
{
515+
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_diag_indices_c<int>};
516+
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_diag_indices_c<long>};
517+
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_diag_indices_c<float>};
518+
fmap[DPNPFuncName::DPNP_FN_DIAG_INDICES][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_diag_indices_c<double>};
519+
506520
fmap[DPNPFuncName::DPNP_FN_DIAGONAL][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_diagonal_c<int>};
507521
fmap[DPNPFuncName::DPNP_FN_DIAGONAL][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_diagonal_c<long>};
508522
fmap[DPNPFuncName::DPNP_FN_DIAGONAL][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_diagonal_c<float>};

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
7474
DPNP_FN_DEGREES
7575
DPNP_FN_DET
7676
DPNP_FN_DIAG
77+
DPNP_FN_DIAG_INDICES
7778
DPNP_FN_DIAGONAL
7879
DPNP_FN_DIVIDE
7980
DPNP_FN_DOT

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ __all__ += [
5454
"dpnp_triu_indices_from"
5555
]
5656

57-
57+
ctypedef void(*fptr_dpnp_diag_indices)(void*, size_t)
5858
ctypedef void(*custom_indexing_2in_1out_func_ptr_t)(void *, const size_t, void * , void * , size_t)
5959
ctypedef void(*custom_indexing_2in_1out_func_ptr_t_)(void * , const size_t, void * , const size_t, size_t * , size_t * , const size_t)
6060
ctypedef void(*custom_indexing_2in_func_ptr_t)(void *, void * , size_t * , const size_t)
@@ -75,15 +75,25 @@ cpdef utils.dpnp_descriptor dpnp_choose(object input, list choices):
7575

7676

7777
cpdef tuple dpnp_diag_indices(n, ndim):
78-
res_item = dpnp.arange(n, dtype=dpnp.int64)
78+
cdef size_t res_size = 0 if n < 0 else n
7979

80-
# yes, all are the same item
81-
result = []
80+
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dpnp.int64)
81+
82+
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DIAG_INDICES, param1_type, param1_type)
83+
84+
cdef fptr_dpnp_diag_indices func = <fptr_dpnp_diag_indices> kernel_data.ptr
85+
86+
res_list = []
87+
cdef utils.dpnp_descriptor res_arr
88+
cdef shape_type_c result_shape = utils._object_to_tuple(res_size)
8289
for i in range(ndim):
83-
result.append(res_item)
90+
res_arr = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
8491

85-
return tuple(result)
92+
func(res_arr.get_data(), res_size)
8693

94+
res_list.append(res_arr.get_pyobj())
95+
96+
return tuple(res_list)
8797

8898
cpdef utils.dpnp_descriptor dpnp_diagonal(dpnp_descriptor input, offset=0):
8999
cdef shape_type_c input_shape = input.shape

0 commit comments

Comments
 (0)