|
56 | 56 | dpnp_eigh, |
57 | 57 | dpnp_inv, |
58 | 58 | dpnp_lstsq, |
| 59 | + dpnp_lu_factor, |
59 | 60 | dpnp_matrix_power, |
60 | 61 | dpnp_matrix_rank, |
61 | 62 | dpnp_multi_dot, |
|
79 | 80 | "eigvalsh", |
80 | 81 | "inv", |
81 | 82 | "lstsq", |
| 83 | + "lu_factor", |
82 | 84 | "matmul", |
83 | 85 | "matrix_norm", |
84 | 86 | "matrix_power", |
@@ -901,6 +903,68 @@ def lstsq(a, b, rcond=None): |
901 | 903 | return dpnp_lstsq(a, b, rcond=rcond) |
902 | 904 |
|
903 | 905 |
|
| 906 | +def lu_factor(a, overwrite_a=False, check_finite=True): |
| 907 | + """ |
| 908 | + Compute the pivoted LU decomposition of a matrix. |
| 909 | +
|
| 910 | + The decomposition is:: |
| 911 | +
|
| 912 | + A = P @ L @ U |
| 913 | +
|
| 914 | + where `P` is a permutation matrix, `L` is lower triangular with unit |
| 915 | + diagonal elements, and `U` is upper triangular. |
| 916 | +
|
| 917 | + Parameters |
| 918 | + ---------- |
| 919 | + a : (M, N) {dpnp.ndarray, usm_ndarray} |
| 920 | + Input array to decompose. |
| 921 | + overwrite_a : {None, bool}, optional |
| 922 | + Whether to overwrite data in `a` (may increase performance) |
| 923 | + Default: ``False``. |
| 924 | + check_finite : {None, bool}, optional |
| 925 | + Whether to check that the input matrix contains only finite numbers. |
| 926 | + Disabling may give a performance gain, but may result in problems |
| 927 | + (crashes, non-termination) if the inputs do contain infinities or NaNs. |
| 928 | +
|
| 929 | + Returns |
| 930 | + ------- |
| 931 | + lu :(M, N) dpnp.ndarray |
| 932 | + Matrix containing U in its upper triangle, and L in its lower triangle. |
| 933 | + The unit diagonal elements of L are not stored. |
| 934 | + piv (K, ): dpnp.ndarray |
| 935 | + Pivot indices representing the permutation matrix P: |
| 936 | + row i of matrix was interchanged with row piv[i]. |
| 937 | + ``K = min(M, N)``. |
| 938 | +
|
| 939 | + Warning |
| 940 | + ------- |
| 941 | + This function synchronizes in order to validate array elements |
| 942 | + when ``check_finite=True``. |
| 943 | +
|
| 944 | + Limitations |
| 945 | + ----------- |
| 946 | + Only two-dimensional input matrices are supported. |
| 947 | + Otherwise, the function raises ``NotImplementedError`` exception. |
| 948 | +
|
| 949 | + Examples |
| 950 | + -------- |
| 951 | + >>> import dpnp as np |
| 952 | + >>> a = np.array([[4., 3.], [6., 3.]]) |
| 953 | + >>> lu, piv = np.linalg.lu_factor(a) |
| 954 | + >>> lu |
| 955 | + array([[6. , 3. ], |
| 956 | + [0.66666667, 1. ]]) |
| 957 | + >>> piv |
| 958 | + array([1, 1]) |
| 959 | +
|
| 960 | + """ |
| 961 | + |
| 962 | + dpnp.check_supported_arrays_type(a) |
| 963 | + assert_stacked_2d(a) |
| 964 | + |
| 965 | + return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite) |
| 966 | + |
| 967 | + |
904 | 968 | def matmul(x1, x2, /): |
905 | 969 | """ |
906 | 970 | Computes the matrix product. |
|
0 commit comments