Skip to content

Commit 1e09cb7

Browse files
Update dependency logic
1 parent be2725a commit 1e09cb7

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,6 +2572,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25722572
)
25732573
_manager.add_event_pair(ht_ev, lu_copy_ev)
25742574

2575+
# oneMKL LAPACK getrf overwrites `piv`.
2576+
piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)
2577+
2578+
# use DPCTL tensor function to fill the сopy of the pivot array
2579+
# from the pivot array
2580+
ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2581+
src=piv_usm_arr,
2582+
dst=piv_h.get_array(),
2583+
sycl_queue=piv.sycl_queue,
2584+
depends=dep_evs,
2585+
)
2586+
_manager.add_event_pair(ht_ev, piv_copy_ev)
2587+
25752588
# SciPy-compatible behavior
25762589
# Copy is required if:
25772590
# - overwrite_b is False (always copy),
@@ -2582,31 +2595,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
25822595
b_h = dpnp.empty_like(
25832596
b, order="F", dtype=res_type, usm_type=res_usm_type
25842597
)
2585-
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2598+
ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
25862599
src=b_usm_arr,
25872600
dst=b_h.get_array(),
25882601
sycl_queue=b.sycl_queue,
2589-
depends=_manager.submitted_events,
2602+
depends=dep_evs,
25902603
)
2591-
_manager.add_event_pair(ht_ev, dep_ev)
2592-
dep_ev = [dep_ev]
2604+
_manager.add_event_pair(ht_ev, b_copy_ev)
2605+
dep_evs = [lu_copy_ev, piv_copy_ev, b_copy_ev]
25932606
else:
25942607
# input is suitable for in-place modification
25952608
b_h = b
2596-
dep_ev = _manager.submitted_events
2597-
2598-
# oneMKL LAPACK getrf overwrites `piv`.
2599-
piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)
2609+
dep_evs = [lu_copy_ev, piv_copy_ev]
26002610

2601-
# use DPCTL tensor function to fill the сopy of the pivot array
2602-
# from the pivot array
2603-
ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2604-
src=piv_usm_arr,
2605-
dst=piv_h.get_array(),
2606-
sycl_queue=piv.sycl_queue,
2607-
depends=dep_evs,
2608-
)
2609-
_manager.add_event_pair(ht_ev, piv_copy_ev)
26102611
# MKL lapack uses 1-origin while SciPy uses 0-origin
26112612
piv_h += 1
26122613

@@ -2619,7 +2620,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
26192620
piv_h.get_array(),
26202621
b_h.get_array(),
26212622
trans,
2622-
depends=dep_ev,
2623+
depends=dep_evs,
26232624
)
26242625
_manager.add_event_pair(ht_ev, getrs_ev)
26252626

0 commit comments

Comments
 (0)