@@ -2868,26 +2868,20 @@ def dpnp_solve(a, b):
2868
2868
a_usm_arr = dpnp .get_usm_ndarray (a )
2869
2869
b_usm_arr = dpnp .get_usm_ndarray (b )
2870
2870
2871
- # Due to MKLD-17226 (bug with incorrect checking ldb parameter
2872
- # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error
2873
- # `invalid argument` when nrhs > n) we can not use _gesv directly.
2874
- # This w/a uses _getrf and _getrs instead
2875
- # to handle cases where nrhs > n for a.shape = (n x n)
2876
- # and b.shape = (n x nrhs).
2877
-
2878
- # oneMKL LAPACK getrf overwrites `a`.
2879
- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type , usm_type = res_usm_type )
2871
+ # oneMKL LAPACK gesv overwrites `a` and assumes Fortran-like array as
2872
+ # input
2873
+ a_h = dpnp .empty_like (a , order = "F" , dtype = res_type , usm_type = res_usm_type )
2880
2874
2881
2875
_manager = dpu .SequentialOrderManager [exec_q ]
2882
- dev_evs = _manager .submitted_events
2876
+ dep_evs = _manager .submitted_events
2883
2877
2884
2878
# use DPCTL tensor function to fill the copy of the input array
2885
2879
# from the input array
2886
2880
ht_ev , a_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2887
2881
src = a_usm_arr ,
2888
2882
dst = a_h .get_array (),
2889
2883
sycl_queue = a .sycl_queue ,
2890
- depends = dev_evs ,
2884
+ depends = dep_evs ,
2891
2885
)
2892
2886
_manager .add_event_pair (ht_ev , a_copy_ev )
2893
2887
@@ -2903,43 +2897,18 @@ def dpnp_solve(a, b):
2903
2897
src = b_usm_arr ,
2904
2898
dst = b_h .get_array (),
2905
2899
sycl_queue = b .sycl_queue ,
2906
- depends = dev_evs ,
2900
+ depends = dep_evs ,
2907
2901
)
2908
2902
_manager .add_event_pair (ht_ev , b_copy_ev )
2909
2903
2910
- n = a .shape [0 ]
2911
-
2912
- ipiv_h = dpnp .empty_like (
2913
- a ,
2914
- shape = (n ,),
2915
- dtype = dpnp .int64 ,
2904
+ # Call the LAPACK extension function _gesv to solve the system of linear
2905
+ # equations with the coefficient square matrix and
2906
+ # the dependent variables array
2907
+ ht_lapack_ev , gesv_ev = li ._gesv (
2908
+ exec_q , a_h .get_array (), b_h .get_array (), [a_copy_ev , b_copy_ev ]
2916
2909
)
2917
- dev_info_h = [0 ]
2918
2910
2919
- # Call the LAPACK extension function _getrf
2920
- # to perform LU decomposition of the input matrix
2921
- ht_ev , getrf_ev = li ._getrf (
2922
- exec_q ,
2923
- a_h .get_array (),
2924
- ipiv_h .get_array (),
2925
- dev_info_h ,
2926
- depends = [a_copy_ev ],
2927
- )
2928
- _manager .add_event_pair (ht_ev , getrf_ev )
2929
-
2930
- _check_lapack_dev_info (dev_info_h )
2931
-
2932
- # Call the LAPACK extension function _getrs
2933
- # to solve the system of linear equations with an LU-factored
2934
- # coefficient square matrix, with multiple right-hand sides.
2935
- ht_ev , getrs_ev = li ._getrs (
2936
- exec_q ,
2937
- a_h .get_array (),
2938
- ipiv_h .get_array (),
2939
- b_h .get_array (),
2940
- depends = [b_copy_ev , getrf_ev ],
2941
- )
2942
- _manager .add_event_pair (ht_ev , getrs_ev )
2911
+ _manager .add_event_pair (ht_lapack_ev , gesv_ev )
2943
2912
return b_h
2944
2913
2945
2914
0 commit comments