Skip to content

Commit 5584348

Browse files
Minor updates dpnp.linalg.lu_factor() (#2570)
This PR suggests minor upgrades for `dpnp.linalg.lu_factor()` including docstring fixes, clearer warnings, improved dependency handling, adding a function to documentation generation and updating tests to use `dpnp.allclose` instead of `assert_allclose`
1 parent 3c8ac08 commit 5584348

File tree

4 files changed

+31
-28
lines changed

4 files changed

+31
-28
lines changed

doc/reference/linalg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Decompositions
4343
dpnp.linalg.cholesky
4444
dpnp.linalg.outer
4545
dpnp.linalg.qr
46+
dpnp.linalg.lu_factor
4647
dpnp.linalg.svd
4748
dpnp.linalg.svdvals
4849

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -914,38 +914,39 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
914914
where `P` is a permutation matrix, `L` is lower triangular with unit
915915
diagonal elements, and `U` is upper triangular.
916916
917+
For full documentation refer to :obj:`scipy.linalg.lu_factor`.
918+
917919
Parameters
918920
----------
919-
a : (M, N) {dpnp.ndarray, usm_ndarray}
921+
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
920922
Input array to decompose.
921923
overwrite_a : {None, bool}, optional
922-
Whether to overwrite data in `a` (may increase performance)
924+
Whether to overwrite data in `a` (may increase performance).
925+
923926
Default: ``False``.
924927
check_finite : {None, bool}, optional
925928
Whether to check that the input matrix contains only finite numbers.
926929
Disabling may give a performance gain, but may result in problems
927930
(crashes, non-termination) if the inputs do contain infinities or NaNs.
928931
932+
Default: ``True``.
933+
929934
Returns
930935
-------
931-
lu :(M, N) dpnp.ndarray
932-
Matrix containing U in its upper triangle, and L in its lower triangle.
933-
The unit diagonal elements of L are not stored.
934-
piv (K, ): dpnp.ndarray
935-
Pivot indices representing the permutation matrix P:
936+
lu : (..., M, N) dpnp.ndarray
937+
Matrix containing `U` in its upper triangle,
938+
and `L` in its lower triangle.
939+
The unit diagonal elements of `L` are not stored.
940+
piv : (..., K) dpnp.ndarray
941+
Pivot indices representing the permutation matrix `P`:
936942
row i of matrix was interchanged with row piv[i].
937-
``K = min(M, N)``.
943+
Where ``K = min(M, N)``.
938944
939945
Warning
940946
-------
941947
This function synchronizes in order to validate array elements
942948
when ``check_finite=True``.
943949
944-
Limitations
945-
-----------
946-
Only two-dimensional input matrices are supported.
947-
Otherwise, the function raises ``NotImplementedError`` exception.
948-
949950
Examples
950951
--------
951952
>>> import dpnp as np

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
481481
if any(dev_info_h):
482482
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
483483
warn(
484-
f"Diagonal number {diag_nums} are exactly zero. "
484+
f"Diagonal numbers {diag_nums} are exactly zero. "
485485
"Singular matrix.",
486486
RuntimeWarning,
487487
stacklevel=2,
@@ -2463,17 +2463,18 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24632463
# - not writeable
24642464
if not overwrite_a or _is_copy_required(a, res_type):
24652465
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
2466-
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
2466+
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray(
24672467
src=a_usm_arr,
24682468
dst=a_h.get_array(),
24692469
sycl_queue=a_sycl_queue,
24702470
depends=_manager.submitted_events,
24712471
)
2472-
_manager.add_event_pair(ht_ev, copy_ev)
2472+
_manager.add_event_pair(ht_ev, dep_ev)
2473+
dep_ev = [dep_ev]
24732474
else:
24742475
# input is suitable for in-place modification
24752476
a_h = a
2476-
copy_ev = None
2477+
dep_ev = _manager.submitted_events
24772478

24782479
m, n = a.shape
24792480

@@ -2493,14 +2494,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24932494
a_h.get_array(),
24942495
ipiv_h.get_array(),
24952496
dev_info_h,
2496-
depends=[copy_ev] if copy_ev is not None else [],
2497+
depends=dep_ev,
24972498
)
24982499
_manager.add_event_pair(ht_ev, getrf_ev)
24992500

25002501
if any(dev_info_h):
25012502
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
25022503
warn(
2503-
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
2504+
f"Diagonal number {diag_nums} is exactly zero. Singular matrix.",
25042505
RuntimeWarning,
25052506
stacklevel=2,
25062507
)

dpnp/tests/test_linalg.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,7 @@ def test_lu_factor(self, shape, order, dtype):
19111911
A_cast = a_dp.astype(LU.dtype, copy=False)
19121912
PA = self._apply_pivots_rows(A_cast, piv)
19131913

1914-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1914+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19151915

19161916
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
19171917
def test_overwrite_inplace(self, dtype):
@@ -1928,7 +1928,7 @@ def test_overwrite_inplace(self, dtype):
19281928
PA = self._apply_pivots_rows(a_dp_orig, piv)
19291929
LU = L @ U
19301930

1931-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1931+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19321932

19331933
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
19341934
def test_overwrite_copy(self, dtype):
@@ -1945,7 +1945,7 @@ def test_overwrite_copy(self, dtype):
19451945
PA = self._apply_pivots_rows(a_dp_orig, piv)
19461946
LU = L @ U
19471947

1948-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1948+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19491949

19501950
def test_overwrite_copy_special(self):
19511951
# F-contig but dtype != res_type
@@ -1972,7 +1972,7 @@ def test_overwrite_copy_special(self):
19721972
a_orig.astype(L.dtype, copy=False), piv
19731973
)
19741974
LU = L @ U
1975-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1975+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19761976

19771977
@pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)])
19781978
def test_empty_inputs(self, shape):
@@ -2003,7 +2003,7 @@ def test_strided(self, sl):
20032003
PA = self._apply_pivots_rows(a_dp, piv)
20042004
LU = L @ U
20052005

2006-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
2006+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
20072007

20082008
def test_singular_matrix(self):
20092009
a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
@@ -2070,7 +2070,7 @@ def test_lu_factor_batched(self, shape, order, dtype):
20702070
L, U = self._split_lu(lu_3d[i], m, n)
20712071
A_cast = a_3d[i].astype(L.dtype, copy=False)
20722072
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2073-
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2073+
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
20742074

20752075
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
20762076
@pytest.mark.parametrize("order", ["C", "F"])
@@ -2082,7 +2082,7 @@ def test_overwrite(self, dtype, order):
20822082
)
20832083

20842084
assert lu is not a_dp
2085-
assert_allclose(a_dp, a_dp_orig)
2085+
assert dpnp.allclose(a_dp, a_dp_orig)
20862086

20872087
m = n = 2
20882088
lu_3d = lu.reshape((-1, m, n))
@@ -2092,7 +2092,7 @@ def test_overwrite(self, dtype, order):
20922092
L, U = self._split_lu(lu_3d[i], m, n)
20932093
A_cast = a_3d[i].astype(L.dtype, copy=False)
20942094
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2095-
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2095+
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
20962096

20972097
@pytest.mark.parametrize(
20982098
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
@@ -2119,7 +2119,7 @@ def test_strided(self):
21192119
PA = self._apply_pivots_rows(
21202120
a_stride[i].astype(L.dtype, copy=False), piv[i]
21212121
)
2122-
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2122+
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
21232123

21242124
def test_singular_matrix(self):
21252125
a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type())

0 commit comments

Comments
 (0)