Skip to content

Commit 6797cf3

Browse files
Use gesv instead of getrf/getrs in dpnp_solve
1 parent 230ba6a commit 6797cf3

File tree

1 file changed

+9
-40
lines changed

1 file changed

+9
-40
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,15 +2619,9 @@ def dpnp_solve(a, b):
26192619
a_usm_arr = dpnp.get_usm_ndarray(a)
26202620
b_usm_arr = dpnp.get_usm_ndarray(b)
26212621

2622-
# Due to MKLD-17226 (bug with incorrect checking ldb parameter
2623-
# in oneapi::mkl::lapack::gesv_scratchad_size that raises an error
2624-
# `invalid argument` when nrhs > n) we can not use _gesv directly.
2625-
# This w/a uses _getrf and _getrs instead
2626-
# to handle cases where nrhs > n for a.shape = (n x n)
2627-
# and b.shape = (n x nrhs).
2628-
2629-
# oneMKL LAPACK getrf overwrites `a`.
2630-
a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type)
2622+
# oneMKL LAPACK getrs overwrites `a` and assumes fortran-like array as
2623+
# input
2624+
a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type)
26312625

26322626
_manager = dpu.SequentialOrderManager[exec_q]
26332627
dev_evs = _manager.submitted_events
@@ -2658,39 +2652,14 @@ def dpnp_solve(a, b):
26582652
)
26592653
_manager.add_event_pair(ht_ev, b_copy_ev)
26602654

2661-
n = a.shape[0]
2662-
2663-
ipiv_h = dpnp.empty_like(
2664-
a,
2665-
shape=(n,),
2666-
dtype=dpnp.int64,
2667-
)
2668-
dev_info_h = [0]
2669-
2670-
# Call the LAPACK extension function _getrf
2671-
# to perform LU decomposition of the input matrix
2672-
ht_ev, getrf_ev = li._getrf(
2673-
exec_q,
2674-
a_h.get_array(),
2675-
ipiv_h.get_array(),
2676-
dev_info_h,
2677-
depends=[a_copy_ev],
2655+
# Call the LAPACK extension function _gesv to solve the system of linear
2656+
# equations with the coefficient square matrix and
2657+
# the dependent variables array.
2658+
ht_lapack_ev, gesv_ev = li._gesv(
2659+
exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev]
26782660
)
2679-
_manager.add_event_pair(ht_ev, getrf_ev)
26802661

2681-
_check_lapack_dev_info(dev_info_h)
2682-
2683-
# Call the LAPACK extension function _getrs
2684-
# to solve the system of linear equations with an LU-factored
2685-
# coefficient square matrix, with multiple right-hand sides.
2686-
ht_ev, getrs_ev = li._getrs(
2687-
exec_q,
2688-
a_h.get_array(),
2689-
ipiv_h.get_array(),
2690-
b_h.get_array(),
2691-
depends=[b_copy_ev, getrf_ev],
2692-
)
2693-
_manager.add_event_pair(ht_ev, getrs_ev)
2662+
_manager.add_event_pair(ht_lapack_ev, gesv_ev)
26942663
return b_h
26952664

26962665

0 commit comments

Comments
 (0)