Skip to content

Commit ccbfccf

Browse files
Solve race conditions issue for pivots
1 parent 52eac3d commit ccbfccf

File tree

1 file changed

+6
-19
lines changed

1 file changed

+6
-19
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,9 +2518,12 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25182518
)
25192519

25202520
lu_usm_arr = dpnp.get_usm_ndarray(lu)
2521-
piv_usm_arr = dpnp.get_usm_ndarray(piv)
25222521
b_usm_arr = dpnp.get_usm_ndarray(b)
25232522

2523+
# dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy,
2524+
# convert to 1-based for oneMKL getrs
2525+
piv_h = piv + 1
2526+
25242527
_manager = dpu.SequentialOrderManager[exec_q]
25252528
dep_evs = _manager.submitted_events
25262529

@@ -2537,19 +2540,6 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25372540
)
25382541
_manager.add_event_pair(ht_ev, lu_copy_ev)
25392542

2540-
# oneMKL LAPACK getrf overwrites `piv`.
2541-
piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)
2542-
2543-
# use DPCTL tensor function to fill the сopy of the pivot array
2544-
# from the pivot array
2545-
ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2546-
src=piv_usm_arr,
2547-
dst=piv_h.get_array(),
2548-
sycl_queue=piv.sycl_queue,
2549-
depends=dep_evs,
2550-
)
2551-
_manager.add_event_pair(ht_ev, piv_copy_ev)
2552-
25532543
# SciPy-compatible behavior
25542544
# Copy is required if:
25552545
# - overwrite_b is False (always copy),
@@ -2567,14 +2557,11 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25672557
depends=dep_evs,
25682558
)
25692559
_manager.add_event_pair(ht_ev, b_copy_ev)
2570-
dep_evs = [lu_copy_ev, piv_copy_ev, b_copy_ev]
2560+
dep_evs = [lu_copy_ev, b_copy_ev]
25712561
else:
25722562
# input is suitable for in-place modification
25732563
b_h = b
2574-
dep_evs = [lu_copy_ev, piv_copy_ev]
2575-
2576-
# MKL lapack uses 1-origin while SciPy uses 0-origin
2577-
piv_h += 1
2564+
dep_evs = [lu_copy_ev]
25782565

25792566
if not isinstance(trans, int):
25802567
raise TypeError("`trans` must be an integer")

0 commit comments

Comments
 (0)