Skip to content

Commit b988e9c

Browse files
authored
Add implementation of dpnp.trace (#1842)
1 parent d9fb840 commit b988e9c

File tree

12 files changed

+219
-354
lines changed

12 files changed

+219
-354
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ enum class DPNPFuncName : size_t
115115
DPNP_FN_DIAG_INDICES_EXT, /**< Used in numpy.diag_indices() impl, requires
116116
extra parameters */
117117
DPNP_FN_DIAGONAL, /**< Used in numpy.diagonal() impl */
118-
DPNP_FN_DIAGONAL_EXT, /**< Used in numpy.diagonal() impl, requires extra
119-
parameters */
120118
DPNP_FN_DIVIDE, /**< Used in numpy.divide() impl */
121119
DPNP_FN_DOT, /**< Used in numpy.dot() impl */
122120
DPNP_FN_DOT_EXT, /**< Used in numpy.dot() impl, requires extra parameters */
@@ -343,8 +341,6 @@ enum class DPNPFuncName : size_t
343341
DPNP_FN_TANH, /**< Used in numpy.tanh() impl */
344342
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
345343
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
346-
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
347-
parameters */
348344
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
349345
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
350346
parameters */

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -723,15 +723,6 @@ void (*dpnp_trace_default_c)(const void *,
723723
const size_t) =
724724
dpnp_trace_c<_DataType, _ResultType>;
725725

726-
template <typename _DataType, typename _ResultType>
727-
DPCTLSyclEventRef (*dpnp_trace_ext_c)(DPCTLSyclQueueRef,
728-
const void *,
729-
void *,
730-
const shape_elem_type *,
731-
const size_t,
732-
const DPCTLEventVectorRef) =
733-
dpnp_trace_c<_DataType, _ResultType>;
734-
735726
template <typename _DataType>
736727
class dpnp_tri_c_kernel;
737728

@@ -1288,39 +1279,6 @@ void func_map_init_arraycreation(func_map_t &fmap)
12881279
fmap[DPNPFuncName::DPNP_FN_TRACE][eft_DBL][eft_DBL] = {
12891280
eft_DBL, (void *)dpnp_trace_default_c<double, double>};
12901281

1291-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_INT][eft_INT] = {
1292-
eft_INT, (void *)dpnp_trace_ext_c<int32_t, int32_t>};
1293-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_LNG][eft_INT] = {
1294-
eft_INT, (void *)dpnp_trace_ext_c<int64_t, int32_t>};
1295-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_FLT][eft_INT] = {
1296-
eft_INT, (void *)dpnp_trace_ext_c<float, int32_t>};
1297-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_DBL][eft_INT] = {
1298-
eft_INT, (void *)dpnp_trace_ext_c<double, int32_t>};
1299-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_INT][eft_LNG] = {
1300-
eft_LNG, (void *)dpnp_trace_ext_c<int32_t, int64_t>};
1301-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_LNG][eft_LNG] = {
1302-
eft_LNG, (void *)dpnp_trace_ext_c<int64_t, int64_t>};
1303-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_FLT][eft_LNG] = {
1304-
eft_LNG, (void *)dpnp_trace_ext_c<float, int64_t>};
1305-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_DBL][eft_LNG] = {
1306-
eft_LNG, (void *)dpnp_trace_ext_c<double, int64_t>};
1307-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_INT][eft_FLT] = {
1308-
eft_FLT, (void *)dpnp_trace_ext_c<int32_t, float>};
1309-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_LNG][eft_FLT] = {
1310-
eft_FLT, (void *)dpnp_trace_ext_c<int64_t, float>};
1311-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_FLT][eft_FLT] = {
1312-
eft_FLT, (void *)dpnp_trace_ext_c<float, float>};
1313-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_DBL][eft_FLT] = {
1314-
eft_FLT, (void *)dpnp_trace_ext_c<double, float>};
1315-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_INT][eft_DBL] = {
1316-
eft_DBL, (void *)dpnp_trace_ext_c<int32_t, double>};
1317-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_LNG][eft_DBL] = {
1318-
eft_DBL, (void *)dpnp_trace_ext_c<int64_t, double>};
1319-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_FLT][eft_DBL] = {
1320-
eft_DBL, (void *)dpnp_trace_ext_c<float, double>};
1321-
fmap[DPNPFuncName::DPNP_FN_TRACE_EXT][eft_DBL][eft_DBL] = {
1322-
eft_DBL, (void *)dpnp_trace_ext_c<double, double>};
1323-
13241282
fmap[DPNPFuncName::DPNP_FN_TRI][eft_INT][eft_INT] = {
13251283
eft_INT, (void *)dpnp_tri_default_c<int32_t>};
13261284
fmap[DPNPFuncName::DPNP_FN_TRI][eft_LNG][eft_LNG] = {

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,6 @@ void (*dpnp_diagonal_default_c)(void *,
293293
shape_elem_type *,
294294
const size_t) = dpnp_diagonal_c<_DataType>;
295295

296-
template <typename _DataType>
297-
DPCTLSyclEventRef (*dpnp_diagonal_ext_c)(DPCTLSyclQueueRef,
298-
void *,
299-
const size_t,
300-
void *,
301-
const size_t,
302-
shape_elem_type *,
303-
shape_elem_type *,
304-
const size_t,
305-
const DPCTLEventVectorRef) =
306-
dpnp_diagonal_c<_DataType>;
307-
308296
template <typename _DataType>
309297
DPCTLSyclEventRef
310298
dpnp_fill_diagonal_c(DPCTLSyclQueueRef q_ref,
@@ -927,19 +915,6 @@ void func_map_init_indexing_func(func_map_t &fmap)
927915
fmap[DPNPFuncName::DPNP_FN_DIAGONAL][eft_DBL][eft_DBL] = {
928916
eft_DBL, (void *)dpnp_diagonal_default_c<double>};
929917

930-
fmap[DPNPFuncName::DPNP_FN_DIAGONAL_EXT][eft_INT][eft_INT] = {
931-
eft_INT, (void *)dpnp_diagonal_ext_c<int32_t>};
932-
fmap[DPNPFuncName::DPNP_FN_DIAGONAL_EXT][eft_LNG][eft_LNG] = {
933-
eft_LNG, (void *)dpnp_diagonal_ext_c<int64_t>};
934-
fmap[DPNPFuncName::DPNP_FN_DIAGONAL_EXT][eft_FLT][eft_FLT] = {
935-
eft_FLT, (void *)dpnp_diagonal_ext_c<float>};
936-
fmap[DPNPFuncName::DPNP_FN_DIAGONAL_EXT][eft_DBL][eft_DBL] = {
937-
eft_DBL, (void *)dpnp_diagonal_ext_c<double>};
938-
fmap[DPNPFuncName::DPNP_FN_DIAGONAL_EXT][eft_C64][eft_C64] = {
939-
eft_C64, (void *)dpnp_diagonal_ext_c<std::complex<float>>};
940-
fmap[DPNPFuncName::DPNP_FN_DIAGONAL_EXT][eft_C128][eft_C128] = {
941-
eft_C128, (void *)dpnp_diagonal_ext_c<std::complex<double>>};
942-
943918
fmap[DPNPFuncName::DPNP_FN_FILL_DIAGONAL][eft_INT][eft_INT] = {
944919
eft_INT, (void *)dpnp_fill_diagonal_default_c<int32_t>};
945920
fmap[DPNPFuncName::DPNP_FN_FILL_DIAGONAL][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3939
DPNP_FN_CORRELATE_EXT
4040
DPNP_FN_DEGREES_EXT
4141
DPNP_FN_DIAG_INDICES_EXT
42-
DPNP_FN_DIAGONAL_EXT
4342
DPNP_FN_EDIFF1D_EXT
4443
DPNP_FN_ERF_EXT
4544
DPNP_FN_FABS_EXT
@@ -90,7 +89,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
9089
DPNP_FN_RNG_WALD_EXT
9190
DPNP_FN_RNG_WEIBULL_EXT
9291
DPNP_FN_RNG_ZIPF_EXT
93-
DPNP_FN_TRACE_EXT
9492
DPNP_FN_TRAPZ_EXT
9593

9694
cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncType": # need this namespace for Enum import

dpnp/dpnp_algo/dpnp_algo_arraycreation.pxi

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ and the rest of the library
3737

3838
__all__ += [
3939
"dpnp_copy",
40-
"dpnp_trace",
4140
]
4241

4342

@@ -73,59 +72,7 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_1out_func_ptr_t)(c_dpctl.DPC
7372
const size_t ,
7473
const int,
7574
const c_dpctl.DPCTLEventVectorRef) except +
76-
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_trace_t)(c_dpctl.DPCTLSyclQueueRef,
77-
const void *,
78-
void * ,
79-
const shape_elem_type * ,
80-
const size_t,
81-
const c_dpctl.DPCTLEventVectorRef) except +
8275

8376

8477
cpdef utils.dpnp_descriptor dpnp_copy(utils.dpnp_descriptor x1):
8578
return call_fptr_1in_1out_strides(DPNP_FN_COPY_EXT, x1)
86-
87-
88-
cpdef utils.dpnp_descriptor dpnp_trace(utils.dpnp_descriptor arr, offset=0, axis1=0, axis2=1, dtype=None, out=None):
89-
if dtype is None:
90-
dtype_ = arr.dtype
91-
else:
92-
dtype_ = dtype
93-
94-
cdef utils.dpnp_descriptor diagonal_arr = dpnp_diagonal(arr, offset)
95-
cdef size_t diagonal_ndim = diagonal_arr.ndim
96-
cdef shape_type_c diagonal_shape = diagonal_arr.shape
97-
98-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
99-
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(dtype_)
100-
101-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRACE_EXT, param1_type, param2_type)
102-
103-
arr_obj = arr.get_array()
104-
105-
# create result array with type given by FPTR data
106-
cdef shape_type_c result_shape = diagonal_shape[:-1]
107-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
108-
kernel_data.return_type,
109-
None,
110-
device=arr_obj.sycl_device,
111-
usm_type=arr_obj.usm_type,
112-
sycl_queue=arr_obj.sycl_queue)
113-
114-
result_sycl_queue = result.get_array().sycl_queue
115-
116-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
117-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
118-
119-
cdef fptr_dpnp_trace_t func = <fptr_dpnp_trace_t > kernel_data.ptr
120-
121-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
122-
diagonal_arr.get_data(),
123-
result.get_data(),
124-
diagonal_shape.data(),
125-
diagonal_ndim,
126-
NULL) # dep_events_ref
127-
128-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
129-
c_dpctl.DPCTLEvent_Delete(event_ref)
130-
131-
return result

dpnp/dpnp_algo/dpnp_algo_indexing.pxi

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_choose_t)(c_dpctl.DPCTLSyclQueueRe
5353
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_diag_indices)(c_dpctl.DPCTLSyclQueueRef,
5454
void * , size_t,
5555
const c_dpctl.DPCTLEventVectorRef)
56-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef,
57-
void * ,
58-
const size_t,
59-
void * ,
60-
const size_t,
61-
shape_elem_type * ,
62-
shape_elem_type *,
63-
const size_t,
64-
const c_dpctl.DPCTLEventVectorRef)
6556
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
6657
void *, void * , shape_elem_type * , const size_t,
6758
const c_dpctl.DPCTLEventVectorRef)
@@ -146,62 +137,6 @@ cpdef tuple dpnp_diag_indices(n, ndim):
146137
return tuple(res_list)
147138

148139

149-
cpdef utils.dpnp_descriptor dpnp_diagonal(dpnp_descriptor x1, offset=0):
150-
cdef shape_type_c x1_shape = x1.shape
151-
152-
n = min(x1.shape[0], x1.shape[1])
153-
res_shape = [None] * (x1.ndim - 1)
154-
155-
if x1.ndim > 2:
156-
for i in range(x1.ndim - 2):
157-
res_shape[i] = x1.shape[i + 2]
158-
159-
if (n + offset) > x1.shape[1]:
160-
res_shape[-1] = x1.shape[1] - offset
161-
elif (n + offset) > x1.shape[0]:
162-
res_shape[-1] = x1.shape[0]
163-
else:
164-
res_shape[-1] = n + offset
165-
166-
cdef shape_type_c result_shape = res_shape
167-
res_ndim = len(res_shape)
168-
169-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
170-
171-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DIAGONAL_EXT, param1_type, param1_type)
172-
173-
x1_obj = x1.get_array()
174-
175-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
176-
kernel_data.return_type,
177-
None,
178-
device=x1_obj.sycl_device,
179-
usm_type=x1_obj.usm_type,
180-
sycl_queue=x1_obj.sycl_queue)
181-
182-
result_sycl_queue = result.get_array().sycl_queue
183-
184-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
185-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
186-
187-
cdef custom_indexing_2in_1out_func_ptr_t_ func = <custom_indexing_2in_1out_func_ptr_t_ > kernel_data.ptr
188-
189-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
190-
x1.get_data(),
191-
x1.size,
192-
result.get_data(),
193-
offset,
194-
x1_shape.data(),
195-
result_shape.data(),
196-
res_ndim,
197-
NULL) # dep_events_ref
198-
199-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
200-
c_dpctl.DPCTLEvent_Delete(event_ref)
201-
202-
return result
203-
204-
205140
cpdef dpnp_fill_diagonal(dpnp_descriptor x1, val):
206141
x1_obj = x1.get_array()
207142

dpnp/dpnp_array.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1393,7 +1393,18 @@ def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
13931393
# 'tofile',
13941394
# 'tolist',
13951395
# 'tostring',
1396-
# 'trace',
1396+
1397+
def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
1398+
"""
1399+
Return the sum along diagonals of the array.
1400+
1401+
Refer to :obj:`dpnp.trace` for full documentation.
1402+
1403+
"""
1404+
1405+
return dpnp.trace(
1406+
self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out
1407+
)
13971408

13981409
def transpose(self, *axes):
13991410
"""

0 commit comments

Comments
 (0)