Skip to content

Commit e141839

Browse files
Implement dpnp.linalg.lu_factor 2D inputs (#2557)
This PR suggests adding `dpnp.linalg.lu_factor()` for 2D arrays similar to `scipy.linalg.lu_factor` Support for ND inputs will be added in the next phase. In addition, this PR includes: 1. An updated implementation of `getrf` to support non-square matrices. 2. Refactoring of `_lu_factor()` by splitting the logic into separate functions to improve readability and maintainability.
1 parent 5307607 commit e141839

File tree

8 files changed

+511
-131
lines changed

8 files changed

+511
-131
lines changed

dpnp/backend/extensions/lapack/getrf.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ namespace py = pybind11;
4444
namespace type_utils = dpctl::tensor::type_utils;
4545

4646
typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue &,
47+
const std::int64_t,
4748
const std::int64_t,
4849
char *,
4950
std::int64_t,
@@ -56,6 +57,7 @@ static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];
5657

5758
template <typename T>
5859
static sycl::event getrf_impl(sycl::queue &exec_q,
60+
const std::int64_t m,
5961
const std::int64_t n,
6062
char *in_a,
6163
std::int64_t lda,
@@ -69,7 +71,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
6971
T *a = reinterpret_cast<T *>(in_a);
7072

7173
const std::int64_t scratchpad_size =
72-
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda);
74+
mkl_lapack::getrf_scratchpad_size<T>(exec_q, m, n, lda);
7375
T *scratchpad = nullptr;
7476

7577
std::stringstream error_msg;
@@ -82,13 +84,13 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
8284

8385
getrf_event = mkl_lapack::getrf(
8486
exec_q,
85-
n, // The order of the square matrix A (0 ≤ n).
87+
m, // The number of rows in the input matrix A (0 ≤ m).
8688
// It must be a non-negative integer.
87-
n, // The number of columns in the square matrix A (0 ≤ n).
89+
n, // The number of columns in the input matrix A (0 ≤ n).
8890
// It must be a non-negative integer.
89-
a, // Pointer to the square matrix A (n x n).
91+
a, // Pointer to the input matrix A (m x n).
9092
lda, // The leading dimension of matrix A.
91-
// It must be at least max(1, n).
93+
// It must be at least max(1, m).
9294
ipiv, // Pointer to the output array of pivot indices.
9395
scratchpad, // Pointer to scratchpad memory to be used by MKL
9496
// routine for storing intermediate results.
@@ -99,7 +101,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
99101

100102
if (info < 0) {
101103
error_msg << "Parameter number " << -info
102-
<< " had an illegal value.";
104+
<< " had an illegal value";
103105
}
104106
else if (info == scratchpad_size && e.detail() != 0) {
105107
error_msg
@@ -168,13 +170,13 @@ std::pair<sycl::event, sycl::event>
168170
if (a_array_nd != 2) {
169171
throw py::value_error(
170172
"The input array has ndim=" + std::to_string(a_array_nd) +
171-
", but a 2-dimensional array is expected.");
173+
", but a 2-dimensional array is expected");
172174
}
173175

174176
if (ipiv_array_nd != 1) {
175177
throw py::value_error("The array of pivot indices has ndim=" +
176178
std::to_string(ipiv_array_nd) +
177-
", but a 1-dimensional array is expected.");
179+
", but a 1-dimensional array is expected");
178180
}
179181

180182
// check compatibility of execution queue and allocation queue
@@ -190,10 +192,11 @@ std::pair<sycl::event, sycl::event>
190192
}
191193

192194
bool is_a_array_c_contig = a_array.is_c_contiguous();
195+
bool is_a_array_f_contig = a_array.is_f_contiguous();
193196
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
194-
if (!is_a_array_c_contig) {
197+
if (!is_a_array_c_contig && !is_a_array_f_contig) {
195198
throw py::value_error("The input array "
196-
"must be C-contiguous");
199+
"must be contiguous");
197200
}
198201
if (!is_ipiv_array_c_contig) {
199202
throw py::value_error("The array of pivot indices "
@@ -208,27 +211,33 @@ std::pair<sycl::event, sycl::event>
208211
if (getrf_fn == nullptr) {
209212
throw py::value_error(
210213
"No getrf implementation defined for the provided type "
211-
"of the input matrix.");
214+
"of the input matrix");
212215
}
213216

214217
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
215218
int ipiv_array_type_id =
216219
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
217220

218221
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
219-
throw py::value_error("The type of 'ipiv_array' must be int64.");
222+
throw py::value_error("The type of 'ipiv_array' must be int64");
220223
}
221224

222-
const std::int64_t n = a_array.get_shape_raw()[0];
225+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
226+
const std::int64_t m = a_array_shape[0];
227+
const std::int64_t n = a_array_shape[1];
228+
const std::int64_t lda = std::max<size_t>(1UL, m);
229+
230+
if (ipiv_array.get_size() != std::min(m, n)) {
231+
throw py::value_error("The size of 'ipiv_array' must be min(m, n)");
232+
}
223233

224234
char *a_array_data = a_array.get_data();
225-
const std::int64_t lda = std::max<size_t>(1UL, n);
226235

227236
char *ipiv_array_data = ipiv_array.get_data();
228237
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
229238

230239
std::vector<sycl::event> host_task_events;
231-
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv,
240+
sycl::event getrf_ev = getrf_fn(exec_q, m, n, a_array_data, lda, d_ipiv,
232241
dev_info, host_task_events, depends);
233242

234243
sycl::event args_ev = dpctl::utils::keep_args_alive(

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,11 @@ std::pair<sycl::event, sycl::event>
221221
}
222222

223223
bool is_a_array_c_contig = a_array.is_c_contiguous();
224+
bool is_a_array_f_contig = a_array.is_f_contiguous();
224225
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
225-
if (!is_a_array_c_contig) {
226+
if (!is_a_array_c_contig && !is_a_array_f_contig) {
226227
throw py::value_error("The input array "
227-
"must be C-contiguous");
228+
"must be must contiguous");
228229
}
229230
if (!is_ipiv_array_c_contig) {
230231
throw py::value_error("The array of pivot indices "

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ PYBIND11_MODULE(_lapack_impl, m)
135135

136136
m.def("_getrf", &lapack_ext::getrf,
137137
"Call `getrf` from OneMKL LAPACK library to return "
138-
"the LU factorization of a general n x n matrix",
138+
"the LU factorization of a general m x n matrix",
139139
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
140140
py::arg("dev_info"), py::arg("depends") = py::list());
141141

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
dpnp_eigh,
5757
dpnp_inv,
5858
dpnp_lstsq,
59+
dpnp_lu_factor,
5960
dpnp_matrix_power,
6061
dpnp_matrix_rank,
6162
dpnp_multi_dot,
@@ -79,6 +80,7 @@
7980
"eigvalsh",
8081
"inv",
8182
"lstsq",
83+
"lu_factor",
8284
"matmul",
8385
"matrix_norm",
8486
"matrix_power",
@@ -901,6 +903,68 @@ def lstsq(a, b, rcond=None):
901903
return dpnp_lstsq(a, b, rcond=rcond)
902904

903905

906+
def lu_factor(a, overwrite_a=False, check_finite=True):
907+
"""
908+
Compute the pivoted LU decomposition of a matrix.
909+
910+
The decomposition is::
911+
912+
A = P @ L @ U
913+
914+
where `P` is a permutation matrix, `L` is lower triangular with unit
915+
diagonal elements, and `U` is upper triangular.
916+
917+
Parameters
918+
----------
919+
a : (M, N) {dpnp.ndarray, usm_ndarray}
920+
Input array to decompose.
921+
overwrite_a : {None, bool}, optional
922+
Whether to overwrite data in `a` (may increase performance)
923+
Default: ``False``.
924+
check_finite : {None, bool}, optional
925+
Whether to check that the input matrix contains only finite numbers.
926+
Disabling may give a performance gain, but may result in problems
927+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
928+
929+
Returns
930+
-------
931+
lu :(M, N) dpnp.ndarray
932+
Matrix containing U in its upper triangle, and L in its lower triangle.
933+
The unit diagonal elements of L are not stored.
934+
piv (K, ): dpnp.ndarray
935+
Pivot indices representing the permutation matrix P:
936+
row i of matrix was interchanged with row piv[i].
937+
``K = min(M, N)``.
938+
939+
Warning
940+
-------
941+
This function synchronizes in order to validate array elements
942+
when ``check_finite=True``.
943+
944+
Limitations
945+
-----------
946+
Only two-dimensional input matrices are supported.
947+
Otherwise, the function raises ``NotImplementedError`` exception.
948+
949+
Examples
950+
--------
951+
>>> import dpnp as np
952+
>>> a = np.array([[4., 3.], [6., 3.]])
953+
>>> lu, piv = np.linalg.lu_factor(a)
954+
>>> lu
955+
array([[6. , 3. ],
956+
[0.66666667, 1. ]])
957+
>>> piv
958+
array([1, 1])
959+
960+
"""
961+
962+
dpnp.check_supported_arrays_type(a)
963+
assert_stacked_2d(a)
964+
965+
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
966+
967+
904968
def matmul(x1, x2, /):
905969
"""
906970
Computes the matrix product.

0 commit comments

Comments
 (0)