@@ -2518,9 +2518,12 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2518
2518
)
2519
2519
2520
2520
lu_usm_arr = dpnp .get_usm_ndarray (lu )
2521
- piv_usm_arr = dpnp .get_usm_ndarray (piv )
2522
2521
b_usm_arr = dpnp .get_usm_ndarray (b )
2523
2522
2523
+ # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
2524
+ # convert to 1-based for oneMKL getrs
2525
+ piv_h = piv + 1
2526
+
2524
2527
_manager = dpu .SequentialOrderManager [exec_q ]
2525
2528
dep_evs = _manager .submitted_events
2526
2529
@@ -2537,19 +2540,6 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2537
2540
)
2538
2541
_manager .add_event_pair (ht_ev , lu_copy_ev )
2539
2542
2540
- # oneMKL LAPACK getrf overwrites `piv`.
2541
- piv_h = dpnp .empty_like (piv , order = "F" , usm_type = res_usm_type )
2542
-
2543
- # use DPCTL tensor function to fill the copy of the pivot array
2544
- # from the pivot array
2545
- ht_ev , piv_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2546
- src = piv_usm_arr ,
2547
- dst = piv_h .get_array (),
2548
- sycl_queue = piv .sycl_queue ,
2549
- depends = dep_evs ,
2550
- )
2551
- _manager .add_event_pair (ht_ev , piv_copy_ev )
2552
-
2553
2543
# SciPy-compatible behavior
2554
2544
# Copy is required if:
2555
2545
# - overwrite_b is False (always copy),
@@ -2567,14 +2557,11 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2567
2557
depends = dep_evs ,
2568
2558
)
2569
2559
_manager .add_event_pair (ht_ev , b_copy_ev )
2570
- dep_evs = [lu_copy_ev , piv_copy_ev , b_copy_ev ]
2560
+ dep_evs = [lu_copy_ev , b_copy_ev ]
2571
2561
else :
2572
2562
# input is suitable for in-place modification
2573
2563
b_h = b
2574
- dep_evs = [lu_copy_ev , piv_copy_ev ]
2575
-
2576
- # MKL lapack uses 1-origin while SciPy uses 0-origin
2577
- piv_h += 1
2564
+ dep_evs = [lu_copy_ev ]
2578
2565
2579
2566
if not isinstance (trans , int ):
2580
2567
raise TypeError ("`trans` must be an integer" )
0 commit comments