Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Replaced `ci` section in `.pre-commit-config.yaml` with a new GitHub workflow with scheduled run to autoupdate the `pre-commit` configuration [#2542](https://github.com/IntelPython/dpnp/pull/2542)
* FFT module is updated to perform in-place FFT in intermediate steps of ND FFT [#2543](https://github.com/IntelPython/dpnp/pull/2543)
* Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546)
* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558)

### Deprecated

Expand Down
49 changes: 9 additions & 40 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2619,15 +2619,9 @@ def dpnp_solve(a, b):
a_usm_arr = dpnp.get_usm_ndarray(a)
b_usm_arr = dpnp.get_usm_ndarray(b)

# Due to MKLD-17226 (bug with incorrect checking ldb parameter
# in oneapi::mkl::lapack::gesv_scratchad_size that raises an error
# `invalid argument` when nrhs > n) we can not use _gesv directly.
# This w/a uses _getrf and _getrs instead
# to handle cases where nrhs > n for a.shape = (n x n)
# and b.shape = (n x nrhs).

# oneMKL LAPACK getrf overwrites `a`.
a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type)
# oneMKL LAPACK getrs overwrites `a` and assumes fortran-like array as
# input
a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type)

_manager = dpu.SequentialOrderManager[exec_q]
dev_evs = _manager.submitted_events
Expand Down Expand Up @@ -2658,39 +2652,14 @@ def dpnp_solve(a, b):
)
_manager.add_event_pair(ht_ev, b_copy_ev)

n = a.shape[0]

ipiv_h = dpnp.empty_like(
a,
shape=(n,),
dtype=dpnp.int64,
)
dev_info_h = [0]

# Call the LAPACK extension function _getrf
# to perform LU decomposition of the input matrix
ht_ev, getrf_ev = li._getrf(
exec_q,
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
depends=[a_copy_ev],
# Call the LAPACK extension function _gesv to solve the system of linear
# equations with the coefficient square matrix and
# the dependent variables array.
ht_lapack_ev, gesv_ev = li._gesv(
exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev]
)
_manager.add_event_pair(ht_ev, getrf_ev)

_check_lapack_dev_info(dev_info_h)

# Call the LAPACK extension function _getrs
# to solve the system of linear equations with an LU-factored
# coefficient square matrix, with multiple right-hand sides.
ht_ev, getrs_ev = li._getrs(
exec_q,
a_h.get_array(),
ipiv_h.get_array(),
b_h.get_array(),
depends=[b_copy_ev, getrf_ev],
)
_manager.add_event_pair(ht_ev, getrs_ev)
_manager.add_event_pair(ht_lapack_ev, gesv_ev)
return b_h


Expand Down
Loading