Skip to content

Commit 6fe2fe4

Browse files
Add implementation of dpnp.linalg.lu_factor()
1 parent 8a17f67 commit 6fe2fe4

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
dpnp_eigh,
5757
dpnp_inv,
5858
dpnp_lstsq,
59+
dpnp_lu_factor,
5960
dpnp_matrix_power,
6061
dpnp_matrix_rank,
6162
dpnp_multi_dot,
@@ -79,6 +80,7 @@
7980
"eigvalsh",
8081
"inv",
8182
"lstsq",
83+
"lu_factor",
8284
"matmul",
8385
"matrix_norm",
8486
"matrix_power",
@@ -901,6 +903,68 @@ def lstsq(a, b, rcond=None):
901903
return dpnp_lstsq(a, b, rcond=rcond)
902904

903905

906+
def lu_factor(a, overwrite_a=False, check_finite=True):
907+
"""
908+
Compute the pivoted LU decomposition of a matrix.
909+
910+
The decomposition is::
911+
912+
A = P @ L @ U
913+
914+
where `P` is a permutation matrix, `L` is lower triangular with unit
915+
diagonal elements, and `U` is upper triangular.
916+
917+
Parameters
918+
----------
919+
a : (M, N) {dpnp.ndarray, usm_ndarray}
920+
Input array to decompose.
921+
overwrite_a : {None, bool}, optional
922+
Whether to overwrite data in `a` (may increase performance)
923+
Default: ``False``.
924+
check_finite : {None, bool}, optional
925+
Whether to check that the input matrix contains only finite numbers.
926+
Disabling may give a performance gain, but may result in problems
927+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
928+
929+
Returns
930+
-------
931+
lu :(M, N) dpnp.ndarray
932+
Matrix containing U in its upper triangle, and L in its lower triangle.
933+
The unit diagonal elements of L are not stored.
934+
piv (K, ): dpnp.ndarray
935+
Pivot indices representing the permutation matrix P:
936+
row i of matrix was interchanged with row piv[i].
937+
``K = min(M, N)``.
938+
939+
Warning
940+
-------
941+
This function synchronizes in order to validate array elements
942+
when ``check_finite=True``.
943+
944+
Limitations
945+
-----------
946+
Only two-dimensional input matrices are supported.
947+
Otherwise, the function raises ``NotImplementedError`` exception.
948+
949+
Examples
950+
--------
951+
>>> import dpnp as np
952+
>>> a = np.array([[4., 3.], [6., 3.]])
953+
>>> lu, piv = np.linalg.lu_factor(a)
954+
>>> lu
955+
array([[6. , 3. ],
956+
[0.66666667, 1. ]])
957+
>>> piv
958+
array([1, 1])
959+
960+
"""
961+
962+
dpnp.check_supported_arrays_type(a)
963+
assert_stacked_2d(a)
964+
965+
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
966+
967+
904968
def matmul(x1, x2, /):
905969
"""
906970
Computes the matrix product.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2307,7 +2307,7 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
23072307
# accommodate empty arrays
23082308
if a.size == 0:
23092309
lu = dpnp.empty_like(a)
2310-
piv = dpnp.arange(0, dtype=dpnp.int32)
2310+
piv = dpnp.arange(0, dtype=dpnp.int64)
23112311
return lu, piv
23122312

23132313
if check_finite:

0 commit comments

Comments
 (0)