Skip to content

Commit 6d339e9

Browse files
Use oneMKL LAPACK gesv for dpnp.linalg.solve() (#2558)
This PR suggests using oneMKL LAPACK `gesv` instead of `getrf` and `getrs` in `dpnp.linalg.solve()`, since the issues in oneMKL have been resolved. This removes the workaround implemented in #1763.
1 parent 8133c3c commit 6d339e9

File tree

3 files changed

+16
-46
lines changed

3 files changed

+16
-46
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3737
* Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546)
3838
* Changed Windows-specific logic in dpnp initialization [#2553](https://github.com/IntelPython/dpnp/pull/2553)
3939
* Added missing includes to files in ufunc and VM pybind11 extensions [#2571](https://github.com/IntelPython/dpnp/pull/2571)
40+
* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558)
4041

4142
### Deprecated
4243

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2868,26 +2868,20 @@ def dpnp_solve(a, b):
28682868
a_usm_arr = dpnp.get_usm_ndarray(a)
28692869
b_usm_arr = dpnp.get_usm_ndarray(b)
28702870

2871-
# Due to MKLD-17226 (bug with incorrect checking ldb parameter
2872-
# in oneapi::mkl::lapack::gesv_scratchpad_size that raises an error
2873-
# `invalid argument` when nrhs > n) we can not use _gesv directly.
2874-
# This w/a uses _getrf and _getrs instead
2875-
# to handle cases where nrhs > n for a.shape = (n x n)
2876-
# and b.shape = (n x nrhs).
2877-
2878-
# oneMKL LAPACK getrf overwrites `a`.
2879-
a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type)
2871+
# oneMKL LAPACK gesv overwrites `a` and assumes fortran-like array as
2872+
# input
2873+
a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type)
28802874

28812875
_manager = dpu.SequentialOrderManager[exec_q]
2882-
dev_evs = _manager.submitted_events
2876+
dep_evs = _manager.submitted_events
28832877

28842878
# use DPCTL tensor function to fill the copy of the input array
28852879
# from the input array
28862880
ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
28872881
src=a_usm_arr,
28882882
dst=a_h.get_array(),
28892883
sycl_queue=a.sycl_queue,
2890-
depends=dev_evs,
2884+
depends=dep_evs,
28912885
)
28922886
_manager.add_event_pair(ht_ev, a_copy_ev)
28932887

@@ -2903,43 +2897,18 @@ def dpnp_solve(a, b):
29032897
src=b_usm_arr,
29042898
dst=b_h.get_array(),
29052899
sycl_queue=b.sycl_queue,
2906-
depends=dev_evs,
2900+
depends=dep_evs,
29072901
)
29082902
_manager.add_event_pair(ht_ev, b_copy_ev)
29092903

2910-
n = a.shape[0]
2911-
2912-
ipiv_h = dpnp.empty_like(
2913-
a,
2914-
shape=(n,),
2915-
dtype=dpnp.int64,
2904+
# Call the LAPACK extension function _gesv to solve the system of linear
2905+
# equations with the coefficient square matrix and
2906+
# the dependent variables array
2907+
ht_lapack_ev, gesv_ev = li._gesv(
2908+
exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev]
29162909
)
2917-
dev_info_h = [0]
29182910

2919-
# Call the LAPACK extension function _getrf
2920-
# to perform LU decomposition of the input matrix
2921-
ht_ev, getrf_ev = li._getrf(
2922-
exec_q,
2923-
a_h.get_array(),
2924-
ipiv_h.get_array(),
2925-
dev_info_h,
2926-
depends=[a_copy_ev],
2927-
)
2928-
_manager.add_event_pair(ht_ev, getrf_ev)
2929-
2930-
_check_lapack_dev_info(dev_info_h)
2931-
2932-
# Call the LAPACK extension function _getrs
2933-
# to solve the system of linear equations with an LU-factored
2934-
# coefficient square matrix, with multiple right-hand sides.
2935-
ht_ev, getrs_ev = li._getrs(
2936-
exec_q,
2937-
a_h.get_array(),
2938-
ipiv_h.get_array(),
2939-
b_h.get_array(),
2940-
depends=[b_copy_ev, getrf_ev],
2941-
)
2942-
_manager.add_event_pair(ht_ev, getrs_ev)
2911+
_manager.add_event_pair(ht_lapack_ev, gesv_ev)
29432912
return b_h
29442913

29452914

dpnp/tests/test_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,7 +2841,7 @@ def test_solve(self, dtype):
28412841
expected = numpy.linalg.solve(a_np, a_np)
28422842
result = dpnp.linalg.solve(a_dp, a_dp)
28432843

2844-
assert_allclose(result, expected)
2844+
assert_dtype_allclose(result, expected)
28452845

28462846
@testing.with_requires("numpy>=2.0")
28472847
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@@ -2914,12 +2914,12 @@ def test_solve_strides(self):
29142914
# positive strides
29152915
expected = numpy.linalg.solve(a_np[::2, ::2], b_np[::2])
29162916
result = dpnp.linalg.solve(a_dp[::2, ::2], b_dp[::2])
2917-
assert_allclose(result, expected, rtol=1e-6)
2917+
assert_dtype_allclose(result, expected)
29182918

29192919
# negative strides
29202920
expected = numpy.linalg.solve(a_np[::-2, ::-2], b_np[::-2])
29212921
result = dpnp.linalg.solve(a_dp[::-2, ::-2], b_dp[::-2])
2922-
assert_allclose(result, expected)
2922+
assert_dtype_allclose(result, expected)
29232923

29242924
@pytest.mark.parametrize(
29252925
"matrix, vector",

0 commit comments

Comments
 (0)