@@ -2619,15 +2619,9 @@ def dpnp_solve(a, b):
2619
2619
a_usm_arr = dpnp .get_usm_ndarray (a )
2620
2620
b_usm_arr = dpnp .get_usm_ndarray (b )
2621
2621
2622
- # Due to MKLD-17226 (bug with incorrect checking ldb parameter
2623
- # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error
2624
- # `invalid argument` when nrhs > n) we can not use _gesv directly.
2625
- # This w/a uses _getrf and _getrs instead
2626
- # to handle cases where nrhs > n for a.shape = (n x n)
2627
- # and b.shape = (n x nrhs).
2628
-
2629
- # oneMKL LAPACK getrf overwrites `a`.
2630
- a_h = dpnp .empty_like (a , order = "C" , dtype = res_type , usm_type = res_usm_type )
2622
+ # oneMKL LAPACK getrs overwrites `a` and assumes fortran-like array as
2623
+ # input
2624
+ a_h = dpnp .empty_like (a , order = "F" , dtype = res_type , usm_type = res_usm_type )
2631
2625
2632
2626
_manager = dpu .SequentialOrderManager [exec_q ]
2633
2627
dev_evs = _manager .submitted_events
@@ -2658,39 +2652,14 @@ def dpnp_solve(a, b):
2658
2652
)
2659
2653
_manager .add_event_pair (ht_ev , b_copy_ev )
2660
2654
2661
- n = a .shape [0 ]
2662
-
2663
- ipiv_h = dpnp .empty_like (
2664
- a ,
2665
- shape = (n ,),
2666
- dtype = dpnp .int64 ,
2667
- )
2668
- dev_info_h = [0 ]
2669
-
2670
- # Call the LAPACK extension function _getrf
2671
- # to perform LU decomposition of the input matrix
2672
- ht_ev , getrf_ev = li ._getrf (
2673
- exec_q ,
2674
- a_h .get_array (),
2675
- ipiv_h .get_array (),
2676
- dev_info_h ,
2677
- depends = [a_copy_ev ],
2655
+ # Call the LAPACK extension function _gesv to solve the system of linear
2656
+ # equations with the coefficient square matrix and
2657
+ # the dependent variables array.
2658
+ ht_lapack_ev , gesv_ev = li ._gesv (
2659
+ exec_q , a_h .get_array (), b_h .get_array (), [a_copy_ev , b_copy_ev ]
2678
2660
)
2679
- _manager .add_event_pair (ht_ev , getrf_ev )
2680
2661
2681
- _check_lapack_dev_info (dev_info_h )
2682
-
2683
- # Call the LAPACK extension function _getrs
2684
- # to solve the system of linear equations with an LU-factored
2685
- # coefficient square matrix, with multiple right-hand sides.
2686
- ht_ev , getrs_ev = li ._getrs (
2687
- exec_q ,
2688
- a_h .get_array (),
2689
- ipiv_h .get_array (),
2690
- b_h .get_array (),
2691
- depends = [b_copy_ev , getrf_ev ],
2692
- )
2693
- _manager .add_event_pair (ht_ev , getrs_ev )
2662
+ _manager .add_event_pair (ht_lapack_ev , gesv_ev )
2694
2663
return b_h
2695
2664
2696
2665
0 commit comments