Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546)
* Changed Windows-specific logic in dpnp initialization [#2553](https://github.com/IntelPython/dpnp/pull/2553)
* Added missing includes to files in ufunc and VM pybind11 extensions [#2571](https://github.com/IntelPython/dpnp/pull/2571)
* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558)

### Deprecated

Expand Down
49 changes: 9 additions & 40 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2868,15 +2868,9 @@ def dpnp_solve(a, b):
a_usm_arr = dpnp.get_usm_ndarray(a)
b_usm_arr = dpnp.get_usm_ndarray(b)

# Due to MKLD-17226 (a bug in oneapi::mkl::lapack::gesv_scratchpad_size
# that incorrectly checks the ldb parameter and raises an `invalid argument`
# error when nrhs > n), we cannot use _gesv directly.
# This workaround uses _getrf and _getrs instead
# to handle cases where nrhs > n for a.shape = (n x n)
# and b.shape = (n x nrhs).

# oneMKL LAPACK getrf overwrites `a`.
a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type)
# oneMKL LAPACK gesv overwrites `a` and assumes fortran-like array as
# input
a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type)

_manager = dpu.SequentialOrderManager[exec_q]
dev_evs = _manager.submitted_events
Expand Down Expand Up @@ -2907,39 +2901,14 @@ def dpnp_solve(a, b):
)
_manager.add_event_pair(ht_ev, b_copy_ev)

n = a.shape[0]

ipiv_h = dpnp.empty_like(
a,
shape=(n,),
dtype=dpnp.int64,
)
dev_info_h = [0]

# Call the LAPACK extension function _getrf
# to perform LU decomposition of the input matrix
ht_ev, getrf_ev = li._getrf(
exec_q,
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
depends=[a_copy_ev],
# Call the LAPACK extension function _gesv to solve the system of linear
# equations with the coefficient square matrix and
# the dependent variables array.
ht_lapack_ev, gesv_ev = li._gesv(
exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev]
)
_manager.add_event_pair(ht_ev, getrf_ev)

_check_lapack_dev_info(dev_info_h)

# Call the LAPACK extension function _getrs
# to solve the system of linear equations with an LU-factored
# coefficient square matrix, with multiple right-hand sides.
ht_ev, getrs_ev = li._getrs(
exec_q,
a_h.get_array(),
ipiv_h.get_array(),
b_h.get_array(),
depends=[b_copy_ev, getrf_ev],
)
_manager.add_event_pair(ht_ev, getrs_ev)
_manager.add_event_pair(ht_lapack_ev, gesv_ev)
return b_h


Expand Down
6 changes: 3 additions & 3 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2841,7 +2841,7 @@ def test_solve(self, dtype):
expected = numpy.linalg.solve(a_np, a_np)
result = dpnp.linalg.solve(a_dp, a_dp)

assert_allclose(result, expected)
assert_dtype_allclose(result, expected)

@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
Expand Down Expand Up @@ -2914,12 +2914,12 @@ def test_solve_strides(self):
# positive strides
expected = numpy.linalg.solve(a_np[::2, ::2], b_np[::2])
result = dpnp.linalg.solve(a_dp[::2, ::2], b_dp[::2])
assert_allclose(result, expected, rtol=1e-6)
assert_dtype_allclose(result, expected)

# negative strides
expected = numpy.linalg.solve(a_np[::-2, ::-2], b_np[::-2])
result = dpnp.linalg.solve(a_dp[::-2, ::-2], b_dp[::-2])
assert_allclose(result, expected)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize(
"matrix, vector",
Expand Down
Loading