Skip to content

Commit 55625eb

Browse files
Apply remarks
1 parent 6e7b9f7 commit 55625eb

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -914,9 +914,11 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
914914
where `P` is a permutation matrix, `L` is lower triangular with unit
915915
diagonal elements, and `U` is upper triangular.
916916
917+
For full documentation refer to :obj:`scipy.linalg.lu_factor`.
918+
917919
Parameters
918920
----------
919-
a : (M, N) {dpnp.ndarray, usm_ndarray}
921+
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
920922
Input array to decompose.
921923
overwrite_a : {None, bool}, optional
922924
Whether to overwrite data in `a` (may increase performance).
@@ -931,13 +933,14 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
931933
932934
Returns
933935
-------
934-
lu : (M, N) dpnp.ndarray
935-
Matrix containing U in its upper triangle, and L in its lower triangle.
936-
The unit diagonal elements of L are not stored.
937-
piv : (K, ) dpnp.ndarray
938-
Pivot indices representing the permutation matrix P:
936+
lu : (..., M, N) dpnp.ndarray
937+
Matrix containing `U` in its upper triangle,
938+
and `L` in its lower triangle.
939+
The unit diagonal elements of `L` are not stored.
940+
piv : (..., K) dpnp.ndarray
941+
Pivot indices representing the permutation matrix `P`:
939942
row i of matrix was interchanged with row piv[i].
940-
``K = min(M, N)``.
943+
Where ``K = min(M, N)``.
941944
942945
Warning
943946
-------

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,17 +2463,17 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24632463
# - not writeable
24642464
if not overwrite_a or _is_copy_required(a, res_type):
24652465
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
2466-
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2466+
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray(
24672467
src=a_usm_arr,
24682468
dst=a_h.get_array(),
24692469
sycl_queue=a_sycl_queue,
24702470
depends=_manager.submitted_events,
24712471
)
2472-
_manager.add_event_pair(ht_ev, copy_ev)
2472+
_manager.add_event_pair(ht_ev, dep_ev)
24732473
else:
24742474
# input is suitable for in-place modification
24752475
a_h = a
2476-
copy_ev = None
2476+
dep_ev = _manager.submitted_events
24772477

24782478
m, n = a.shape
24792479

@@ -2493,7 +2493,7 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24932493
a_h.get_array(),
24942494
ipiv_h.get_array(),
24952495
dev_info_h,
2496-
depends=[copy_ev] if copy_ev is not None else _manager.submitted_events,
2496+
depends=dep_ev,
24972497
)
24982498
_manager.add_event_pair(ht_ev, getrf_ev)
24992499

0 commit comments

Comments
 (0)