@@ -2572,6 +2572,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2572
2572
)
2573
2573
_manager .add_event_pair (ht_ev , lu_copy_ev )
2574
2574
2575
+ # oneMKL LAPACK getrf overwrites `piv`.
2576
+ piv_h = dpnp .empty_like (piv , order = "F" , usm_type = res_usm_type )
2577
+
2578
+ # use DPCTL tensor function to fill the сopy of the pivot array
2579
+ # from the pivot array
2580
+ ht_ev , piv_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2581
+ src = piv_usm_arr ,
2582
+ dst = piv_h .get_array (),
2583
+ sycl_queue = piv .sycl_queue ,
2584
+ depends = dep_evs ,
2585
+ )
2586
+ _manager .add_event_pair (ht_ev , piv_copy_ev )
2587
+
2575
2588
# SciPy-compatible behavior
2576
2589
# Copy is required if:
2577
2590
# - overwrite_b is False (always copy),
@@ -2582,31 +2595,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2582
2595
b_h = dpnp .empty_like (
2583
2596
b , order = "F" , dtype = res_type , usm_type = res_usm_type
2584
2597
)
2585
- ht_ev , dep_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2598
+ ht_ev , b_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2586
2599
src = b_usm_arr ,
2587
2600
dst = b_h .get_array (),
2588
2601
sycl_queue = b .sycl_queue ,
2589
- depends = _manager . submitted_events ,
2602
+ depends = dep_evs ,
2590
2603
)
2591
- _manager .add_event_pair (ht_ev , dep_ev )
2592
- dep_ev = [dep_ev ]
2604
+ _manager .add_event_pair (ht_ev , b_copy_ev )
2605
+ dep_evs = [lu_copy_ev , piv_copy_ev , b_copy_ev ]
2593
2606
else :
2594
2607
# input is suitable for in-place modification
2595
2608
b_h = b
2596
- dep_ev = _manager .submitted_events
2597
-
2598
- # oneMKL LAPACK getrf overwrites `piv`.
2599
- piv_h = dpnp .empty_like (piv , order = "F" , usm_type = res_usm_type )
2609
+ dep_evs = [lu_copy_ev , piv_copy_ev ]
2600
2610
2601
- # use DPCTL tensor function to fill the сopy of the pivot array
2602
- # from the pivot array
2603
- ht_ev , piv_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
2604
- src = piv_usm_arr ,
2605
- dst = piv_h .get_array (),
2606
- sycl_queue = piv .sycl_queue ,
2607
- depends = dep_evs ,
2608
- )
2609
- _manager .add_event_pair (ht_ev , piv_copy_ev )
2610
2611
# MKL lapack uses 1-origin while SciPy uses 0-origin
2611
2612
piv_h += 1
2612
2613
@@ -2619,7 +2620,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
2619
2620
piv_h .get_array (),
2620
2621
b_h .get_array (),
2621
2622
trans ,
2622
- depends = dep_ev ,
2623
+ depends = dep_evs ,
2623
2624
)
2624
2625
_manager .add_event_pair (ht_ev , getrs_ev )
2625
2626
0 commit comments