Skip to content

Commit d34b8ce

Browse files
Update dpnp.linalg.matrix_rank() implementation (#1717)
* Update dpnp.linalg.matrix_rank impl * Add cupy tests for dpnp.linalg.matrix_rank * Add dpnp tests for dpnp.linalg.matrix_rank * Remove old impl of dpnp_matrix_rank * Address remarks * Address remarks
1 parent a100705 commit d34b8ce

File tree

10 files changed

+253
-142
lines changed

10 files changed

+253
-142
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,26 @@ enum class DPNPFuncName : size_t
178178
DPNP_FN_KRON, /**< Used in numpy.kron() impl */
179179
DPNP_FN_KRON_EXT, /**< Used in numpy.kron() impl, requires extra parameters
180180
*/
181-
DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */
182-
DPNP_FN_LOG, /**< Used in numpy.log() impl */
183-
DPNP_FN_LOG10, /**< Used in numpy.log10() impl */
184-
DPNP_FN_LOG2, /**< Used in numpy.log2() impl */
185-
DPNP_FN_LOG1P, /**< Used in numpy.log1p() impl */
186-
DPNP_FN_MATMUL, /**< Used in numpy.matmul() impl */
187-
DPNP_FN_MATRIX_RANK, /**< Used in numpy.linalg.matrix_rank() impl */
188-
DPNP_FN_MATRIX_RANK_EXT, /**< Used in numpy.linalg.matrix_rank() impl,
189-
requires extra parameters */
190-
DPNP_FN_MAX, /**< Used in numpy.max() impl */
191-
DPNP_FN_MAXIMUM, /**< Used in numpy.fmax() impl */
192-
DPNP_FN_MAXIMUM_EXT, /**< Used in numpy.fmax() impl , requires extra
193-
parameters */
194-
DPNP_FN_MEAN, /**< Used in numpy.mean() impl */
195-
DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */
196-
DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra
197-
parameters */
198-
DPNP_FN_MIN, /**< Used in numpy.min() impl */
199-
DPNP_FN_MINIMUM, /**< Used in numpy.fmin() impl */
200-
DPNP_FN_MINIMUM_EXT, /**< Used in numpy.fmax() impl, requires extra
201-
parameters */
202-
DPNP_FN_MODF, /**< Used in numpy.modf() impl */
181+
DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */
182+
DPNP_FN_LOG, /**< Used in numpy.log() impl */
183+
DPNP_FN_LOG10, /**< Used in numpy.log10() impl */
184+
DPNP_FN_LOG2, /**< Used in numpy.log2() impl */
185+
DPNP_FN_LOG1P, /**< Used in numpy.log1p() impl */
186+
DPNP_FN_MATMUL, /**< Used in numpy.matmul() impl */
187+
DPNP_FN_MATRIX_RANK, /**< Used in numpy.linalg.matrix_rank() impl */
188+
DPNP_FN_MAX, /**< Used in numpy.max() impl */
189+
DPNP_FN_MAXIMUM, /**< Used in numpy.fmax() impl */
190+
DPNP_FN_MAXIMUM_EXT, /**< Used in numpy.fmax() impl , requires extra
191+
parameters */
192+
DPNP_FN_MEAN, /**< Used in numpy.mean() impl */
193+
DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */
194+
DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra
195+
parameters */
196+
DPNP_FN_MIN, /**< Used in numpy.min() impl */
197+
DPNP_FN_MINIMUM, /**< Used in numpy.fmin() impl */
198+
DPNP_FN_MINIMUM_EXT, /**< Used in numpy.fmax() impl, requires extra
199+
parameters */
200+
DPNP_FN_MODF, /**< Used in numpy.modf() impl */
203201
DPNP_FN_MODF_EXT, /**< Used in numpy.modf() impl, requires extra parameters
204202
*/
205203
DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() impl */

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -579,15 +579,6 @@ template <typename _DataType>
579579
void (*dpnp_matrix_rank_default_c)(void *, void *, shape_elem_type *, size_t) =
580580
dpnp_matrix_rank_c<_DataType>;
581581

582-
template <typename _DataType>
583-
DPCTLSyclEventRef (*dpnp_matrix_rank_ext_c)(DPCTLSyclQueueRef,
584-
void *,
585-
void *,
586-
shape_elem_type *,
587-
size_t,
588-
const DPCTLEventVectorRef) =
589-
dpnp_matrix_rank_c<_DataType>;
590-
591582
template <typename _InputDT, typename _ComputeDT>
592583
DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
593584
void *array1_in,
@@ -969,15 +960,6 @@ void func_map_init_linalg_func(func_map_t &fmap)
969960
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_DBL][eft_DBL] = {
970961
eft_DBL, (void *)dpnp_matrix_rank_default_c<double>};
971962

972-
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_INT][eft_INT] = {
973-
eft_INT, (void *)dpnp_matrix_rank_ext_c<int32_t>};
974-
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_LNG][eft_LNG] = {
975-
eft_LNG, (void *)dpnp_matrix_rank_ext_c<int64_t>};
976-
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_FLT][eft_FLT] = {
977-
eft_FLT, (void *)dpnp_matrix_rank_ext_c<float>};
978-
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_DBL][eft_DBL] = {
979-
eft_DBL, (void *)dpnp_matrix_rank_ext_c<double>};
980-
981963
fmap[DPNPFuncName::DPNP_FN_QR][eft_INT][eft_INT] = {
982964
eft_DBL, (void *)dpnp_qr_default_c<int32_t, double>};
983965
fmap[DPNPFuncName::DPNP_FN_QR][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
7878
DPNP_FN_FULL_LIKE
7979
DPNP_FN_KRON
8080
DPNP_FN_KRON_EXT
81-
DPNP_FN_MATRIX_RANK
82-
DPNP_FN_MATRIX_RANK_EXT
8381
DPNP_FN_MAXIMUM
8482
DPNP_FN_MAXIMUM_EXT
8583
DPNP_FN_MEDIAN

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,14 @@ __all__ = [
4848
"dpnp_cond",
4949
"dpnp_eig",
5050
"dpnp_eigvals",
51-
"dpnp_matrix_rank",
5251
"dpnp_norm",
5352
]
5453

5554

5655
# C function pointer to the C library template functions
57-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
58-
void *, void * ,shape_elem_type * ,
59-
size_t, const c_dpctl.DPCTLEventVectorRef)
60-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef,
61-
void * , void * , size_t * ,
62-
const c_dpctl.DPCTLEventVectorRef)
6356
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_with_size_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef,
6457
void *, void * , size_t,
6558
const c_dpctl.DPCTLEventVectorRef)
66-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_3out_shape_t)(c_dpctl.DPCTLSyclQueueRef,
67-
void *, void * , void * , void * ,
68-
size_t , size_t, const c_dpctl.DPCTLEventVectorRef)
6959
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_2in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
7060
void *, void * , void * , size_t,
7161
const c_dpctl.DPCTLEventVectorRef)
@@ -183,42 +173,6 @@ cpdef utils.dpnp_descriptor dpnp_eigvals(utils.dpnp_descriptor input):
183173
return res_val
184174

185175

186-
cpdef utils.dpnp_descriptor dpnp_matrix_rank(utils.dpnp_descriptor input):
187-
cdef shape_type_c input_shape = input.shape
188-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
189-
190-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATRIX_RANK_EXT, param1_type, param1_type)
191-
192-
input_obj = input.get_array()
193-
194-
# create result array with type given by FPTR data
195-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor((1,),
196-
kernel_data.return_type,
197-
None,
198-
device=input_obj.sycl_device,
199-
usm_type=input_obj.usm_type,
200-
sycl_queue=input_obj.sycl_queue)
201-
202-
result_sycl_queue = result.get_array().sycl_queue
203-
204-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
205-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
206-
207-
cdef custom_linalg_1in_1out_func_ptr_t func = <custom_linalg_1in_1out_func_ptr_t > kernel_data.ptr
208-
209-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
210-
input.get_data(),
211-
result.get_data(),
212-
input_shape.data(),
213-
input.ndim,
214-
NULL) # dep_events_ref
215-
216-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
217-
c_dpctl.DPCTLEvent_Delete(event_ref)
218-
219-
return result
220-
221-
222176
cpdef object dpnp_norm(object input, ord=None, axis=None):
223177
cdef long size_input = input.size
224178
cdef shape_type_c shape_input = input.shape

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
dpnp_det,
5252
dpnp_eigh,
5353
dpnp_inv,
54+
dpnp_matrix_rank,
5455
dpnp_pinv,
5556
dpnp_qr,
5657
dpnp_slogdet,
@@ -397,47 +398,57 @@ def matrix_power(input, count):
397398
return call_origin(numpy.linalg.matrix_power, input, count)
398399

399400

400-
def matrix_rank(input, tol=None, hermitian=False):
401+
def matrix_rank(A, tol=None, hermitian=False):
401402
"""
402-
Return matrix rank of array.
403+
Return matrix rank of array using SVD method.
403404
404405
Rank of the array is the number of singular values of the array that are
405406
greater than `tol`.
406407
407408
Parameters
408409
----------
409-
M : {(M,), (..., M, N)} array_like
410+
A : {(M,), (..., M, N)} {dpnp.ndarray, usm_ndarray}
410411
Input vector or stack of matrices.
411-
tol : (...) array_like, float, optional
412+
tol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
412413
Threshold below which SVD values are considered zero. If `tol` is
413414
None, and ``S`` is an array with singular values for `M`, and
414415
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
415416
set to ``S.max() * max(M.shape) * eps``.
416417
hermitian : bool, optional
417-
If True, `M` is assumed to be Hermitian (symmetric if real-valued),
418+
If True, `A` is assumed to be Hermitian (symmetric if real-valued),
418419
enabling a more efficient method for finding singular values.
419420
Defaults to False.
420421
421422
Returns
422423
-------
423-
rank : (...) array_like
424-
Rank of M.
424+
rank : (...) dpnp.ndarray
425+
Rank of A.
425426
426-
"""
427+
See Also
428+
--------
429+
:obj:`dpnp.linalg.svd` : Singular Value Decomposition.
427430
428-
x1_desc = dpnp.get_dpnp_descriptor(input, copy_when_nondefault_queue=False)
429-
if x1_desc:
430-
if tol is not None:
431-
pass
432-
elif hermitian:
433-
pass
434-
else:
435-
result_obj = dpnp_matrix_rank(x1_desc).get_pyobj()
436-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
431+
Examples
432+
--------
433+
>>> import dpnp as np
434+
>>> from dpnp.linalg import matrix_rank
435+
>>> matrix_rank(np.eye(4)) # Full rank matrix
436+
array(4)
437+
>>> I=np.eye(4); I[-1,-1] = 0. # rank deficient matrix
438+
>>> matrix_rank(I)
439+
array(3)
440+
>>> matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0
441+
array(1)
442+
>>> matrix_rank(np.zeros((4,)))
443+
array(0)
437444
438-
return result
445+
"""
446+
447+
dpnp.check_supported_arrays_type(A)
448+
if tol is not None:
449+
dpnp.check_supported_arrays_type(tol, scalar_type=True)
439450

440-
return call_origin(numpy.linalg.matrix_rank, input, tol, hermitian)
451+
return dpnp_matrix_rank(A, tol=tol, hermitian=hermitian)
441452

442453

443454
def multi_dot(arrays, out=None):

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"dpnp_det",
4040
"dpnp_eigh",
4141
"dpnp_inv",
42+
"dpnp_matrix_rank",
4243
"dpnp_pinv",
4344
"dpnp_qr",
4445
"dpnp_slogdet",
@@ -999,6 +1000,29 @@ def dpnp_inv(a):
9991000
return b_f
10001001

10011002

1003+
def dpnp_matrix_rank(A, tol=None, hermitian=False):
1004+
"""
1005+
dpnp_matrix_rank(A, tol=None, hermitian=False)
1006+
1007+
Return matrix rank of array using SVD method.
1008+
1009+
"""
1010+
1011+
if A.ndim < 2:
1012+
return (A != 0).any().astype(int)
1013+
1014+
S = dpnp_svd(A, compute_uv=False, hermitian=hermitian)
1015+
1016+
if tol is None:
1017+
rtol = max(A.shape[-2:]) * dpnp.finfo(S.dtype).eps
1018+
tol = S.max(axis=-1, keepdims=True) * rtol
1019+
elif not dpnp.isscalar(tol):
1020+
# Add a new axis to match Numpy's output
1021+
tol = tol[..., None]
1022+
1023+
return dpnp.count_nonzero(S > tol, axis=-1)
1024+
1025+
10021026
def dpnp_pinv(a, rcond=1e-15, hermitian=False):
10031027
"""
10041028
dpnp_pinv(a, rcond=1e-15, hermitian=False):

0 commit comments

Comments
 (0)