diff --git a/CHANGELOG.md b/CHANGELOG.md index a2ee5771a5b..3e5bc60b80d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This release changes the license from `BSD-2-Clause` to `BSD-3-Clause`. ### Added * Added the docstrings to `dpnp.linalg.LinAlgError` exception [#2613](https://github.com/IntelPython/dpnp/pull/2613) +* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2619](https://github.com/IntelPython/dpnp/pull/2619) ### Changed diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 6b61e092367..9c495e24359 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -41,6 +41,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index f89f06ed1d5..2186620ffb9 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -51,14 +51,14 @@ namespace type_utils = dpctl::tensor::type_utils; using ext::common::init_dispatch_vector; typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &, - oneapi::mkl::transpose, + const oneapi::mkl::transpose, const std::int64_t, const std::int64_t, char *, - std::int64_t, - std::int64_t *, + const std::int64_t, + const std::int64_t *, char *, - std::int64_t, + const std::int64_t, std::vector &, const std::vector &); @@ -70,10 +70,10 @@ static sycl::event getrs_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, char *in_a, - std::int64_t lda, - std::int64_t *ipiv, + const std::int64_t lda, + const std::int64_t *ipiv, char *in_b, - std::int64_t ldb, + const 
std::int64_t ldb, std::vector &host_task_events, const std::vector &depends) { @@ -234,7 +234,7 @@ std::pair throw py::value_error("The right-hand sides array " "must be F-contiguous"); } - if (!is_ipiv_array_c_contig || !is_ipiv_array_f_contig) { + if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) { throw py::value_error("The array of pivot indices " "must be contiguous"); } diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 599bac931e4..f5a47c69c9e 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -40,8 +40,23 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, - oneapi::mkl::transpose trans, + const oneapi::mkl::transpose trans, const std::vector &depends = {}); +extern std::pair + getrs_batch(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &a_array, + const dpctl::tensor::usm_ndarray &ipiv_array, + const dpctl::tensor::usm_ndarray &b_array, + const oneapi::mkl::transpose trans, + const std::int64_t n, + const std::int64_t nrhs, + const std::int64_t stride_a, + const std::int64_t stride_ipiv, + const std::int64_t stride_b, + const std::int64_t batch_size, + const std::vector &depends = {}); + extern void init_getrs_dispatch_vector(void); +extern void init_getrs_batch_dispatch_vector(void); } // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp new file mode 100644 index 00000000000..a526efd6c07 --- /dev/null +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -0,0 +1,353 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#include +#include +#include + +#include +#include + +// utils extension header +#include "ext/common.hpp" + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_utils.hpp" + +#include "getrs.hpp" +#include "linalg_exceptions.hpp" +#include "types_matrix.hpp" + +namespace dpnp::extensions::lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +using ext::common::init_dispatch_vector; + +typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( + sycl::queue &, + oneapi::mkl::transpose, // trans + const std::int64_t, // n + const std::int64_t, // nrhs + char *, // a + const std::int64_t, // lda + const std::int64_t, // stride_a + const std::int64_t *, // ipiv + const std::int64_t, // stride_ipiv + char *, // b + const std::int64_t, // ldb + const std::int64_t, // stride_b + const std::int64_t, // batch_size + std::vector &, + const std::vector &); + +static getrs_batch_impl_fn_ptr_t + getrs_batch_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event getrs_batch_impl(sycl::queue &exec_q, + oneapi::mkl::transpose trans, + const std::int64_t n, + const std::int64_t nrhs, + char *in_a, + const std::int64_t lda, + const std::int64_t stride_a, + const std::int64_t *ipiv, + const std::int64_t stride_ipiv, + char *in_b, + const std::int64_t ldb, + const std::int64_t stride_b, + const std::int64_t batch_size, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *b = reinterpret_cast(in_b); + + const std::int64_t scratchpad_size = + mkl_lapack::getrs_batch_scratchpad_size(exec_q, trans, n, nrhs, lda, + stride_a, stride_ipiv, ldb, + stride_b, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; 
+ bool is_exception_caught = false; + + sycl::event getrs_batch_event; + try { + scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q); + + getrs_batch_event = mkl_lapack::getrs_batch( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of the square matrix A + // and the number of rows in matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in matrix B (0 ≤ nrhs). + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A, + // must be at least max(1, n). + stride_a, // Stride between consecutive A matrices in the batch. + ipiv, // Pointer to the output array of pivot indices that were used + // during factorization (n, ). + stride_ipiv, // Stride between consecutive pivot arrays in the + // batch. + b, // Pointer to the matrix B of right-hand sides (ldb, nrhs). + ldb, // The leading dimension of matrix B, must be at least max(1, + // n). + stride_b, // Stride between consecutive B matrices in the batch. + batch_size, // Total number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::batch_error const &be) { + // Get the indices of matrices within the batch that encountered an + // error + auto error_matrices_ids = be.ids(); + + // OneMKL batched functions throw a single `batch_error` + // instead of per-matrix exceptions or an info array. + // This is interpreted as a computation_error (singular matrix), + // consistent with non-batched LAPACK behavior. 
+ is_exception_caught = false; + if (scratchpad != nullptr) { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q); + } + throw LinAlgError("The solve could not be completed."); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else if (info > 0) { + is_exception_caught = false; + if (scratchpad != nullptr) { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, + exec_q); + } + throw LinAlgError("The solve could not be completed."); + } + else { + error_msg << "Unexpected MKL exception caught during getrs_batch() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrs_batch() " + "call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(getrs_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, ctx); + }); + }); + host_task_events.push_back(clean_up_event); + return getrs_batch_event; +} + +std::pair<sycl::event, sycl::event> + getrs_batch(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &a_array, + const dpctl::tensor::usm_ndarray &ipiv_array, + const dpctl::tensor::usm_ndarray &b_array, + oneapi::mkl::transpose trans, + std::int64_t n, + std::int64_t nrhs, + std::int64_t stride_a, + std::int64_t stride_ipiv, + std::int64_t stride_b, + std::int64_t batch_size, + const 
std::vector<sycl::event> &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int b_array_nd = b_array.get_ndim(); + const int ipiv_array_nd = ipiv_array.get_ndim(); + + if (a_array_nd < 3) { + throw py::value_error( + "The LU-factorized array has ndim=" + std::to_string(a_array_nd) + + ", but an array with ndim >= 3 is expected"); + } + if (b_array_nd < 2) { + throw py::value_error("The right-hand sides array has ndim=" + + std::to_string(b_array_nd) + + ", but an array with ndim >= 2 is expected"); + } + if (ipiv_array_nd < 2) { + throw py::value_error("The array of pivot indices has ndim=" + + std::to_string(ipiv_array_nd) + + ", but an array with ndim >= 2 is expected"); + } + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + if (a_array_shape[0] != a_array_shape[1]) { + throw py::value_error("Expected batch of square matrices, but got " + "matrix shape (" + + std::to_string(a_array_shape[0]) + ", " + + std::to_string(a_array_shape[1]) + ") in batch"); + } + + if (ipiv_array_nd != a_array_nd - 1) { + throw py::value_error( + "The array of pivot indices has ndim=" + + std::to_string(ipiv_array_nd) + + ", but an array with ndim=" + std::to_string(a_array_nd - 1) + + " is expected to match LU batch dimensions"); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(exec_q, + {a_array, b_array, ipiv_array})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, b_array)) { + throw py::value_error("The LU-factorized and right-hand sides arrays " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_a_array_f_contig = a_array.is_f_contiguous(); + bool is_b_array_f_contig = b_array.is_f_contiguous(); + bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); + bool is_ipiv_array_f_contig = 
ipiv_array.is_f_contiguous(); + if (!is_a_array_c_contig && !is_a_array_f_contig) { + throw py::value_error("The LU-factorized array " + "must be either C-contiguous " + "or F-contiguous"); + } + if (!is_b_array_f_contig) { + throw py::value_error("The right-hand sides array " + "must be F-contiguous"); + } + if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) { + throw py::value_error("The array of pivot indices " + "must be contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int b_array_type_id = + array_types.typenum_to_lookup_id(b_array.get_typenum()); + + if (a_array_type_id != b_array_type_id) { + throw py::value_error("The types of the LU-factorized and " + "right-hand sides arrays are mismatched"); + } + + getrs_batch_impl_fn_ptr_t getrs_batch_fn = + getrs_batch_dispatch_vector[a_array_type_id]; + if (getrs_batch_fn == nullptr) { + throw py::value_error( + "No getrs_batch implementation defined for the provided type " + "of the input matrix"); + } + + auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); + int ipiv_array_type_id = + ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); + + if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) { + throw py::value_error("The type of 'ipiv_array' must be int64"); + } + + const std::int64_t lda = std::max<std::int64_t>(1UL, n); + const std::int64_t ldb = std::max<std::int64_t>(1UL, n); + + char *a_array_data = a_array.get_data(); + char *b_array_data = b_array.get_data(); + char *ipiv_array_data = ipiv_array.get_data(); + + std::int64_t *ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data); + + std::vector<sycl::event> host_task_events; + sycl::event getrs_batch_ev = getrs_batch_fn( + exec_q, trans, n, nrhs, a_array_data, lda, stride_a, ipiv, stride_ipiv, + b_array_data, ldb, stride_b, batch_size, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive( + exec_q, {a_array, b_array, ipiv_array}, 
host_task_events); + + return std::make_pair(args_ev, getrs_batch_ev); +} + +template +struct GetrsBatchContigFactory +{ + fnT get() + { + if constexpr (types::GetrsBatchTypePairSupportFactory::is_defined) { + return getrs_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_getrs_batch_dispatch_vector(void) +{ + init_dispatch_vector( + getrs_batch_dispatch_vector); +} +} // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index e71d8826af3..b8a711d5bd0 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -61,6 +61,7 @@ void init_dispatch_vectors(void) lapack_ext::init_getrf_batch_dispatch_vector(); lapack_ext::init_getrf_dispatch_vector(); lapack_ext::init_getri_batch_dispatch_vector(); + lapack_ext::init_getrs_batch_dispatch_vector(); lapack_ext::init_getrs_dispatch_vector(); lapack_ext::init_orgqr_batch_dispatch_vector(); lapack_ext::init_orgqr_dispatch_vector(); @@ -167,12 +168,22 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_getrs", &lapack_ext::getrs, "Call `getrs` from OneMKL LAPACK library to return " - "the solves of linear equations with an LU-factored " + "the solutions of linear equations with an LU-factored " "square coefficient matrix, with multiple right-hand sides", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N, py::arg("depends") = py::list()); + m.def("_getrs_batch", &lapack_ext::getrs_batch, + "Call `getrs_batch` from OneMKL LAPACK library to return " + "the solutions of batch linear equations with an LU-factored " + "square coefficient matrix, with multiple right-hand sides", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), + py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N, + py::arg("n"), py::arg("nrhs"), py::arg("stride_a"), + py::arg("stride_ipiv"), 
py::arg("stride_b"), py::arg("batch_size"), + py::arg("depends") = py::list()); + m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " "the real orthogonal matrix Qi of the QR factorization " diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index ab2794198cb..3d7b5e1b9df 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -248,6 +248,34 @@ struct GetrsTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::getrs_batch + * function. + * + * @tparam T Type of array containing batched input matrix (LU-factored form) + * and the array of multiple dependent variables, + * as well as the output array for storing the solutions to a system of linear + * equations. + */ +template +struct GetrsBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::heevd diff --git a/dpnp/scipy/linalg/_decomp_lu.py b/dpnp/scipy/linalg/_decomp_lu.py index 06ee82330ef..772c24b299d 100644 --- a/dpnp/scipy/linalg/_decomp_lu.py +++ b/dpnp/scipy/linalg/_decomp_lu.py @@ -42,7 +42,10 @@ import dpnp -from dpnp.linalg.dpnp_utils_linalg import assert_stacked_2d +from dpnp.linalg.dpnp_utils_linalg import ( + assert_stacked_2d, + assert_stacked_square, +) from ._utils import ( dpnp_lu_factor, @@ -132,7 +135,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, 
check_finite=True): Parameters ---------- lu, piv : {tuple of dpnp.ndarrays or usm_ndarrays} - LU factorization of matrix `a` (M, M) together with pivot indices. + LU factorization of matrix `a` (..., M, M) together with pivot indices. b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} Right-hand side trans : {0, 1, 2} , optional @@ -160,7 +163,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): Returns ------- - x : {(M,), (M, K)} dpnp.ndarray + x : {(M,), (..., M, K)} dpnp.ndarray Solution to the system Warning @@ -187,6 +190,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): (lu, piv) = lu_and_piv dpnp.check_supported_arrays_type(lu, piv, b) assert_stacked_2d(lu) + assert_stacked_square(lu) return dpnp_lu_solve( lu, diff --git a/dpnp/scipy/linalg/_utils.py b/dpnp/scipy/linalg/_utils.py index a60d333bc55..27a7ffa58bd 100644 --- a/dpnp/scipy/linalg/_utils.py +++ b/dpnp/scipy/linalg/_utils.py @@ -58,6 +58,37 @@ ] +def _align_lu_solve_broadcast(lu, b): + """Align LU and RHS batch dimensions with SciPy-like rules.""" + lu_shape = lu.shape + b_shape = b.shape + + if b.ndim < 2: + if lu_shape[-2] != b_shape[0]: + raise ValueError( + f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" + ) + b = dpnp.broadcast_to(b, lu_shape[:-1]) + return lu, b + + if lu_shape[-2] != b_shape[-2]: + raise ValueError( + f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" + ) + + # Use dpnp.broadcast_shapes() to align the resulting batch shapes + batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2]) + lu_bshape = batch + lu_shape[-2:] + b_bshape = batch + b_shape[-2:] + + if lu_shape != lu_bshape: + lu = dpnp.broadcast_to(lu, lu_bshape) + if b_shape != b_bshape: + b = dpnp.broadcast_to(b, b_bshape) + + return lu, b + + def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals """SciPy-compatible LU factorization for batched inputs.""" @@ -183,6 +214,105 @@ def _batched_lu_factor_scipy(a, 
res_type): # pylint: disable=too-many-locals return (a_h, ipiv_h) +def _batched_lu_solve(lu, piv, b, res_type, trans=0): + """Solve a batched equation system (SciPy-compatible behavior).""" + res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) + + b_ndim = b.ndim + + lu, b = _align_lu_solve_broadcast(lu, b) + + n = lu.shape[-1] + nrhs = b.shape[-1] if b_ndim > 1 else 1 + + # get 3d input arrays by reshape + if lu.ndim > 3: + lu = dpnp.reshape(lu, (-1, n, n)) + # get 2d pivot arrays by reshape + if piv.ndim > 2: + piv = dpnp.reshape(piv, (-1, n)) + batch_size = lu.shape[0] + + # Move batch axis to the end (n, n, batch) in Fortran order: + # required by getrs_batch + # and ensures each a[..., i] is F-contiguous for getrs_batch + lu = dpnp.moveaxis(lu, 0, -1) + + b_orig_shape = b.shape + if b.ndim > 2: + b = dpnp.reshape(b, (-1, n, nrhs)) + + # Move batch axis to the end (n, nrhs, batch) in Fortran order: + # required by getrs_batch + # and ensures each b[..., i] is F-contiguous for getrs_batch + b = dpnp.moveaxis(b, 0, -1) + + lu_usm_arr = dpnp.get_usm_ndarray(lu) + b_usm_arr = dpnp.get_usm_ndarray(b) + + # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy, + # convert to 1-based for oneMKL getrs_batch + piv_h = piv + 1 + + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + # oneMKL LAPACK getrs_batch overwrites `lu` + lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=lu_usm_arr, + dst=lu_h.get_array(), + sycl_queue=lu.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, lu_copy_ev) + + # oneMKL LAPACK getrs_batch overwrites `b` and assumes fortran-like array + # as input + b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + 
src=b_usm_arr, + dst=b_h.get_array(), + sycl_queue=b.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) + dep_evs = [lu_copy_ev, b_copy_ev] + + lu_stride = n * n + piv_stride = n + b_stride = n * nrhs + + trans_mkl = _map_trans_to_mkl(trans) + + # Call the LAPACK extension function _getrs_batch + # to solve the system of linear equations with an LU-factored + # coefficient square matrix, with multiple right-hand sides. + ht_ev, getrs_batch_ev = li._getrs_batch( + exec_q, + lu_h.get_array(), + piv_h.get_array(), + b_h.get_array(), + trans_mkl, + n, + nrhs, + lu_stride, + piv_stride, + b_stride, + batch_size, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, getrs_batch_ev) + + # Restore original shape: move batch axis back and reshape + b_h = dpnp.moveaxis(b_h, -1, 0).reshape(b_orig_shape) + + return b_h + + def _is_copy_required(a, res_type): """ Determine if `a` needs to be copied before LU decomposition. @@ -200,6 +330,20 @@ def _is_copy_required(a, res_type): return False +def _map_trans_to_mkl(trans): + """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum.""" + if not isinstance(trans, int): + raise TypeError("`trans` must be an integer") + + if trans == 0: + return li.Transpose.N + if trans == 1: + return li.Transpose.T + if trans == 2: + return li.Transpose.C + raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + + def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): """ dpnp_lu_factor(a, overwrite_a=False, check_finite=True) @@ -310,18 +454,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): res_type = _common_type(lu, b) - # TODO: add broadcasting - if lu.shape[0] != b.shape[0]: - raise ValueError( - f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" - ) - if b.size == 0: return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) - if lu.ndim > 2: - raise NotImplementedError("Batched matrices are not supported") - if check_finite: if not 
dpnp.isfinite(lu).all(): raise ValueError( @@ -334,6 +469,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): "Right-hand side array must not contain infs or NaNs" ) + if lu.ndim > 2: + # SciPy always copies each 2D slice, + # so `overwrite_b` is ignored here + return _batched_lu_solve(lu, piv, b, trans=trans, res_type=res_type) + + if lu.shape[0] != b.shape[0]: + raise ValueError( + f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" + ) + lu_usm_arr = dpnp.get_usm_ndarray(lu) b_usm_arr = dpnp.get_usm_ndarray(b) @@ -344,7 +489,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # oneMKL LAPACK getrs overwrites `lu`. + # oneMKL LAPACK getrs overwrites `lu`. lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) # use DPCTL tensor function to fill the сopy of the input array @@ -380,18 +525,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): b_h = b dep_evs = [lu_copy_ev] - if not isinstance(trans, int): - raise TypeError("`trans` must be an integer") - - # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums - if trans == 0: - trans_mkl = li.Transpose.N - elif trans == 1: - trans_mkl = li.Transpose.T - elif trans == 2: - trans_mkl = li.Transpose.C - else: - raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + trans_mkl = _map_trans_to_mkl(trans) # Call the LAPACK extension function _getrs # to solve the system of linear equations with an LU-factored diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 7c0753b96fe..55031706ea3 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2312,9 +2312,6 @@ def test_strided_rhs(self): (4,), (4, 1), (4, 3), - # (1, 4, 3), - # (2, 4, 3), - # (1, 1, 4, 3) ], ) def test_broadcast_rhs(self, b_shape): @@ -2335,19 +2332,15 @@ def test_broadcast_rhs(self, b_shape): 
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5) - @pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)]) - @pytest.mark.parametrize("rhs_cols", [None, 0, 3]) - def test_empty_shapes(self, shape, rhs_cols): - a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") - if min(shape) > 0: - for i in range(min(shape)): - a_dp[i, i] = a_dp.dtype.type(1.0) + @pytest.mark.parametrize("a_shape", [(0, 0), (5, 5)]) + @pytest.mark.parametrize("b_shape", [(0,), (0, 0), (0, 5)]) + def test_empty_shapes(self, a_shape, b_shape): + a_dp = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F") + n = a_shape[0] - n = shape[0] - if rhs_cols is None: - b_shape = (n,) - else: - b_shape = (n, rhs_cols) + if n > 0: + for i in range(n): + a_dp[i, i] = a_dp.dtype.type(1.0) b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) @@ -2370,6 +2363,241 @@ def test_check_finite_raises(self, bad): ) +class TestLuSolveBatched: + @staticmethod + def _make_nonsingular_nd_np(shape, dtype, order): + A = generate_random_numpy_array(shape, dtype, order) + n = shape[-1] + A3 = A.reshape((-1, n, n)) + for B in A3: + off = numpy.sum(numpy.abs(B), axis=1) - numpy.abs(numpy.diag(B)) + B[numpy.arange(n), numpy.arange(n)] = A.dtype.type(off + 1.0) + A = A3.reshape(shape) + # Ensure reshapes did not break memory order + A = numpy.array(A, order=order) + return A + + @staticmethod + def _expected_x_shape(a_shape, b_shape): + n = a_shape[-1] + assert a_shape[-2] == n + + a_batch = a_shape[:-2] + if len(b_shape) >= 2 and b_shape[-2] == n: + # b : (..., n, nrhs) + k = b_shape[-1] + b_batch = b_shape[:-2] + exp_batch = numpy.broadcast_shapes(a_batch, b_batch) + return exp_batch + (n, k) + else: + # b : (..., n) + assert b_shape[-1] == n, "b's last dim must equal n" + b_batch = b_shape[:-1] + exp_batch = numpy.broadcast_shapes(a_batch, b_batch) + return exp_batch + (n,) + + @pytest.mark.parametrize( + 
"a_shape, b_shape", + [ + ((1, 2, 2), (2,)), + ((2, 4, 4), (4,)), + ((2, 4, 4), (4, 3)), + ((2, 4, 4), (2, 4, 4)), + ((2, 4, 4), (1, 4, 3)), + ((2, 4, 4), (2, 4, 2)), + ((2, 3, 4, 4), (1, 3, 4, 2)), + ((2, 3, 4, 4), (2, 1, 4, 2)), + ((3, 4, 4), (1, 1, 4, 2)), + ], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_none=True) + ) + def test_lu_solve_batched(self, a_shape, b_shape, dtype, order): + a_np = self._make_nonsingular_nd_np(a_shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + b_np = generate_random_numpy_array(b_shape, dtype, order) + b_dp = dpnp.array(b_np, order=order) + + lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.scipy.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=True, check_finite=False + ) + + exp_shape = self._expected_x_shape(a_shape, b_shape) + assert x.shape == exp_shape + + if b_dp.ndim > 1: + Ax = a_dp @ x + else: + Ax = (a_dp @ x[..., None])[..., 0] + b_exp = dpnp.broadcast_to(b_dp, exp_shape) + assert dpnp.allclose(Ax, b_exp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("trans", [0, 1, 2]) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_trans(self, trans, order, dtype): + a_shape = (3, 4, 4) + b_shape = (3, 4, 2) + + a_np = self._make_nonsingular_nd_np(a_shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + b_dp = dpnp.array( + generate_random_numpy_array(b_shape, dtype, order), order=order + ) + + lu, piv = dpnp.scipy.linalg.lu_factor( + a_dp, overwrite_a=False, check_finite=False + ) + x = dpnp.scipy.linalg.lu_solve( + (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False + ) + + if trans == 0: + lhs = a_dp @ x + elif trans == 1: + lhs = dpnp.swapaxes(a_dp, -1, -2) @ x + else: # trans == 2 + lhs = dpnp.conj(dpnp.swapaxes(a_dp, -1, -2)) @ x + + assert dpnp.allclose(lhs, b_dp, rtol=1e-5, atol=1e-5) + + 
@pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_overwrite(self, dtype, order): + a_np = self._make_nonsingular_nd_np((2, 4, 4), dtype, order) + a_dp = dpnp.array(a_np, order=order) + + lu, piv = dpnp.scipy.linalg.lu_factor( + a_dp, overwrite_a=False, check_finite=False + ) + + b_dp = dpnp.array( + generate_random_numpy_array((2, 4, 2), dtype, "F"), order="F" + ) + b_dp_orig = b_dp.copy() + x = dpnp.scipy.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=True, check_finite=False + ) + + assert x is not b_dp + assert dpnp.allclose(b_dp, b_dp_orig) + + assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5) + + def test_strided(self): + n, B = 4, 6 + a_np = self._make_nonsingular_nd_np( + (B, n, n), dpnp.default_float_type(), "F" + ) + a_dp = dpnp.array(a_np, order="F") + + a_stride = a_dp[::2] + rhs_full = ( + dpnp.arange(B * n * 3, dtype=dpnp.default_float_type()).reshape( + B, n, 3, order="F" + ) + + 1.0 + ) + b_dp = rhs_full[::2, :, ::-1] + + lu, piv = dpnp.scipy.linalg.lu_factor(a_stride, check_finite=False) + x = dpnp.scipy.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=False, check_finite=False + ) + + assert dpnp.allclose(a_stride @ x, b_dp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize( + "dtype_a", get_all_dtypes(no_bool=True, no_none=True) + ) + @pytest.mark.parametrize( + "dtype_b", get_all_dtypes(no_bool=True, no_none=True) + ) + @pytest.mark.parametrize("b_shape", [(4, 2), (1, 4, 2), (2, 4, 2)]) + def test_diff_type(self, dtype_a, dtype_b, b_shape): + B, n, k = 2, 4, 2 + a_np = self._make_nonsingular_nd_np((B, n, n), dtype_a, "F") + a_dp = dpnp.array(a_np, order="F") + + b_np = generate_random_numpy_array(b_shape, dtype_b, "F") + b_dp = dpnp.array(b_np, order="F") + + lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.scipy.linalg.lu_solve((lu, piv), b_dp, check_finite=False) + + exp_shape = (B, n, k) + assert x.shape == exp_shape + + b_exp = 
dpnp.broadcast_to(b_dp, exp_shape) + assert dpnp.allclose( + a_dp @ x, b_exp.astype(x.dtype, copy=False), rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((0, 3, 3), (0, 3)), + ((2, 0, 0), (2, 0)), + ((0, 0, 0), (0, 0)), + ], + ) + def test_empty_shapes(self, a_shape, b_shape): + a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F") + b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") + + lu, piv = dpnp.scipy.linalg.lu_factor(a, check_finite=False) + x = dpnp.scipy.linalg.lu_solve((lu, piv), b, check_finite=False) + + assert x.shape == b_shape + + def test_check_finite_raises(self): + B, n = 2, 3 + a_np = self._make_nonsingular_nd_np( + (B, n, n), dpnp.default_float_type(), "F" + ) + a_dp = dpnp.array(a_np, order="F") + lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) + + b_bad = dpnp.ones((B, n), dtype=dpnp.default_float_type(), order="F") + b_bad[1, 0] = dpnp.nan + assert_raises( + ValueError, + dpnp.scipy.linalg.lu_solve, + (lu, piv), + b_bad, + check_finite=True, + ) + + @pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((2, 4, 4), (2,)), + ((2, 4, 4), (2, 4)), + ((2, 4, 4), (4, 4, 2)), + ((2, 4, 4), (2, 3, 4, 2)), + ((2, 3, 4, 4), (3, 4)), + ((2, 3, 4, 4), (2, 4)), + ((2, 3, 4, 4), (2, 3, 5, 2)), + ], + ) + def test_invalid_shapes(self, a_shape, b_shape): + dtype = dpnp.default_float_type() + a = dpnp.array( + self._make_nonsingular_nd_np(a_shape, dtype, "F"), order="F" + ) + b = dpnp.array( + generate_random_numpy_array(b_shape, dtype, "F"), order="F" + ) + + lu, piv = dpnp.scipy.linalg.lu_factor(a, check_finite=False) + with pytest.raises(ValueError): + dpnp.scipy.linalg.lu_solve((lu, piv), b, check_finite=False) + + class TestMatrixPower: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index bce599c2279..8a034bc8174 100644 --- a/dpnp/tests/test_sycl_queue.py +++ 
b/dpnp/tests/test_sycl_queue.py @@ -1610,11 +1610,16 @@ def test_lu_factor(self, data, device): assert_sycl_queue_equal(param_queue, a.sycl_queue) @pytest.mark.parametrize( - "b_data", - [[1.0, 2.0], numpy.empty((2, 0))], + "a_data, b_data", + [ + ([[1.0, 2.0], [3.0, 5.0]], [1.0, 2.0]), + ([[1.0, 2.0], [3.0, 5.0]], numpy.empty((2, 0))), + ([[[1.0, 2.0], [3.0, 5.0]]], [1.0, 2.0]), + ([[[1.0, 2.0], [3.0, 5.0]]], numpy.empty((2, 0, 2))), + ], ) - def test_lu_solve(self, b_data, device): - a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device) + def test_lu_solve(self, a_data, b_data, device): + a = dpnp.array(a_data, device=device) lu, piv = dpnp.scipy.linalg.lu_factor(a) b = dpnp.array(b_data, device=device) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index eb059e335a3..da1b575166d 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1488,11 +1488,16 @@ def test_lu_factor(self, data, usm_type): @pytest.mark.parametrize("usm_type_rhs", list_of_usm_types) @pytest.mark.parametrize( - "b_data", - [[1.0, 2.0], numpy.empty((2, 0))], + "a_data, b_data", + [ + ([[1.0, 2.0], [3.0, 5.0]], [1.0, 2.0]), + ([[1.0, 2.0], [3.0, 5.0]], numpy.empty((2, 0))), + ([[[1.0, 2.0], [3.0, 5.0]]], [1.0, 2.0]), + ([[[1.0, 2.0], [3.0, 5.0]]], numpy.empty((2, 0, 2))), + ], ) - def test_lu_solve(self, b_data, usm_type, usm_type_rhs): - a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type) + def test_lu_solve(self, a_data, b_data, usm_type, usm_type_rhs): + a = dpnp.array(a_data, usm_type=usm_type) lu, piv = dpnp.scipy.linalg.lu_factor(a) b = dpnp.array(b_data, usm_type=usm_type_rhs)