diff --git a/CHANGELOG.md b/CHANGELOG.md index 708983f5fa6..c8e15f9a79e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546) * Changed Windows-specific logic in dpnp initialization [#2553](https://github.com/IntelPython/dpnp/pull/2553) * Added missing includes to files in ufunc and VM pybind11 extensions [#2571](https://github.com/IntelPython/dpnp/pull/2571) +* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558) ### Deprecated diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 2b8eef552aa..fdf46174bfc 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2868,18 +2868,12 @@ def dpnp_solve(a, b): a_usm_arr = dpnp.get_usm_ndarray(a) b_usm_arr = dpnp.get_usm_ndarray(b) - # Due to MKLD-17226 (bug with incorrect checking ldb parameter - # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error - # `invalid argument` when nrhs > n) we can not use _gesv directly. - # This w/a uses _getrf and _getrs instead - # to handle cases where nrhs > n for a.shape = (n x n) - # and b.shape = (n x nrhs). - - # oneMKL LAPACK getrf overwrites `a`. - a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type) + # oneMKL LAPACK getrs overwrites `a` and assumes fortran-like array as + # input + a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type) _manager = dpu.SequentialOrderManager[exec_q] - dev_evs = _manager.submitted_events + dep_evs = _manager.submitted_events # use DPCTL tensor function to fill the сopy of the input array # from the input array @@ -2887,7 +2881,7 @@ def dpnp_solve(a, b): src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a.sycl_queue, - depends=dev_evs, + depends=dep_evs, ) _manager.add_event_pair(ht_ev, a_copy_ev) @@ -2903,43 +2897,18 @@ def dpnp_solve(a, b): src=b_usm_arr, dst=b_h.get_array(), sycl_queue=b.sycl_queue, - depends=dev_evs, + depends=dep_evs, ) _manager.add_event_pair(ht_ev, b_copy_ev) - n = a.shape[0] - - ipiv_h = dpnp.empty_like( - a, - shape=(n,), - dtype=dpnp.int64, + # Call the LAPACK extension function _gesv to solve the system of linear + # equations with the coefficient square matrix and + # the dependent variables array + ht_lapack_ev, gesv_ev = li._gesv( + exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev] ) - dev_info_h = [0] - # Call the LAPACK extension function _getrf - # to perform LU decomposition of the input matrix - ht_ev, getrf_ev = li._getrf( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - dev_info_h, - depends=[a_copy_ev], - ) - _manager.add_event_pair(ht_ev, getrf_ev) - - _check_lapack_dev_info(dev_info_h) - - # Call the LAPACK extension function _getrs - # to solve the system of linear equations with an LU-factored - # coefficient square matrix, with multiple right-hand sides. - ht_ev, getrs_ev = li._getrs( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - b_h.get_array(), - depends=[b_copy_ev, getrf_ev], - ) - _manager.add_event_pair(ht_ev, getrs_ev) + _manager.add_event_pair(ht_lapack_ev, gesv_ev) return b_h diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index ca57ef7da53..bef8159e6e9 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2841,7 +2841,7 @@ def test_solve(self, dtype): expected = numpy.linalg.solve(a_np, a_np) result = dpnp.linalg.solve(a_dp, a_dp) - assert_allclose(result, expected) + assert_dtype_allclose(result, expected) @testing.with_requires("numpy>=2.0") @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) @@ -2914,12 +2914,12 @@ def test_solve_strides(self): # positive strides expected = numpy.linalg.solve(a_np[::2, ::2], b_np[::2]) result = dpnp.linalg.solve(a_dp[::2, ::2], b_dp[::2]) - assert_allclose(result, expected, rtol=1e-6) + assert_dtype_allclose(result, expected) # negative strides expected = numpy.linalg.solve(a_np[::-2, ::-2], b_np[::-2]) result = dpnp.linalg.solve(a_dp[::-2, ::-2], b_dp[::-2]) - assert_allclose(result, expected) + assert_dtype_allclose(result, expected) @pytest.mark.parametrize( "matrix, vector",