@@ -2619,15 +2619,9 @@ def dpnp_solve(a, b):
26192619 a_usm_arr = dpnp .get_usm_ndarray (a )
26202620 b_usm_arr = dpnp .get_usm_ndarray (b )
26212621
2622- # Due to MKLD-17226 (bug with incorrect checking ldb parameter
2623- # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error
2624- # `invalid argument` when nrhs > n) we can not use _gesv directly.
2625- # This w/a uses _getrf and _getrs instead
2626- # to handle cases where nrhs > n for a.shape = (n x n)
2627- # and b.shape = (n x nrhs).
2628-
2629- # oneMKL LAPACK getrf overwrites `a`.
2630- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type , usm_type = res_usm_type )
2622+ # oneMKL LAPACK getrs overwrites `a` and assumes fortran-like array as
2623+ # input
2624+ a_h = dpnp .empty_like (a , order = "F" , dtype = res_type , usm_type = res_usm_type )
26312625
26322626 _manager = dpu .SequentialOrderManager [exec_q ]
26332627 dev_evs = _manager .submitted_events
@@ -2658,39 +2652,14 @@ def dpnp_solve(a, b):
26582652 )
26592653 _manager .add_event_pair (ht_ev , b_copy_ev )
26602654
2661- n = a .shape [0 ]
2662-
2663- ipiv_h = dpnp .empty_like (
2664- a ,
2665- shape = (n ,),
2666- dtype = dpnp .int64 ,
2667- )
2668- dev_info_h = [0 ]
2669-
2670- # Call the LAPACK extension function _getrf
2671- # to perform LU decomposition of the input matrix
2672- ht_ev , getrf_ev = li ._getrf (
2673- exec_q ,
2674- a_h .get_array (),
2675- ipiv_h .get_array (),
2676- dev_info_h ,
2677- depends = [a_copy_ev ],
2655+ # Call the LAPACK extension function _gesv to solve the system of linear
2656+ # equations with the coefficient square matrix and
2657+ # the dependent variables array.
2658+ ht_lapack_ev , gesv_ev = li ._gesv (
2659+ exec_q , a_h .get_array (), b_h .get_array (), [a_copy_ev , b_copy_ev ]
26782660 )
2679- _manager .add_event_pair (ht_ev , getrf_ev )
26802661
2681- _check_lapack_dev_info (dev_info_h )
2682-
2683- # Call the LAPACK extension function _getrs
2684- # to solve the system of linear equations with an LU-factored
2685- # coefficient square matrix, with multiple right-hand sides.
2686- ht_ev , getrs_ev = li ._getrs (
2687- exec_q ,
2688- a_h .get_array (),
2689- ipiv_h .get_array (),
2690- b_h .get_array (),
2691- depends = [b_copy_ev , getrf_ev ],
2692- )
2693- _manager .add_event_pair (ht_ev , getrs_ev )
2662+ _manager .add_event_pair (ht_lapack_ev , gesv_ev )
26942663 return b_h
26952664
26962665
0 commit comments