Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Decompositions
dpnp.linalg.cholesky
dpnp.linalg.outer
dpnp.linalg.qr
dpnp.linalg.lu_factor
dpnp.linalg.svd
dpnp.linalg.svdvals

Expand Down
9 changes: 6 additions & 3 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,19 +919,22 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
a : (M, N) {dpnp.ndarray, usm_ndarray}
Input array to decompose.
overwrite_a : {None, bool}, optional
Whether to overwrite data in `a` (may increase performance)
Whether to overwrite data in `a` (may increase performance).

Default: ``False``.
check_finite : {None, bool}, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Default: ``True``.

Returns
-------
lu :(M, N) dpnp.ndarray
lu : (M, N) dpnp.ndarray
Matrix containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv (K, ): dpnp.ndarray
piv : (K, ) dpnp.ndarray
Pivot indices representing the permutation matrix P:
row i of matrix was interchanged with row piv[i].
``K = min(M, N)``.
Expand Down
6 changes: 3 additions & 3 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
if any(dev_info_h):
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. "
f"Diagonal numbers {diag_nums} are exactly zero. "
"Singular matrix.",
RuntimeWarning,
stacklevel=2,
Expand Down Expand Up @@ -2493,14 +2493,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
depends=[copy_ev] if copy_ev is not None else [],
depends=[copy_ev] if copy_ev is not None else _manager.submitted_events,
)
_manager.add_event_pair(ht_ev, getrf_ev)

if any(dev_info_h):
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
f"Diagonal number {diag_nums} is exactly zero. Singular matrix.",
RuntimeWarning,
stacklevel=2,
)
Expand Down
18 changes: 9 additions & 9 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,7 +1911,7 @@ def test_lu_factor(self, shape, order, dtype):
A_cast = a_dp.astype(LU.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv)

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_overwrite_inplace(self, dtype):
Expand All @@ -1928,7 +1928,7 @@ def test_overwrite_inplace(self, dtype):
PA = self._apply_pivots_rows(a_dp_orig, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_overwrite_copy(self, dtype):
Expand All @@ -1945,7 +1945,7 @@ def test_overwrite_copy(self, dtype):
PA = self._apply_pivots_rows(a_dp_orig, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

def test_overwrite_copy_special(self):
# F-contig but dtype != res_type
Expand All @@ -1972,7 +1972,7 @@ def test_overwrite_copy_special(self):
a_orig.astype(L.dtype, copy=False), piv
)
LU = L @ U
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)])
def test_empty_inputs(self, shape):
Expand Down Expand Up @@ -2003,7 +2003,7 @@ def test_strided(self, sl):
PA = self._apply_pivots_rows(a_dp, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

def test_singular_matrix(self):
a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
Expand Down Expand Up @@ -2070,7 +2070,7 @@ def test_lu_factor_batched(self, shape, order, dtype):
L, U = self._split_lu(lu_3d[i], m, n)
A_cast = a_3d[i].astype(L.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize("order", ["C", "F"])
Expand All @@ -2082,7 +2082,7 @@ def test_overwrite(self, dtype, order):
)

assert lu is not a_dp
assert_allclose(a_dp, a_dp_orig)
assert dpnp.allclose(a_dp, a_dp_orig)

m = n = 2
lu_3d = lu.reshape((-1, m, n))
Expand All @@ -2092,7 +2092,7 @@ def test_overwrite(self, dtype, order):
L, U = self._split_lu(lu_3d[i], m, n)
A_cast = a_3d[i].astype(L.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize(
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
Expand All @@ -2119,7 +2119,7 @@ def test_strided(self):
PA = self._apply_pivots_rows(
a_stride[i].astype(L.dtype, copy=False), piv[i]
)
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

def test_singular_matrix(self):
a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type())
Expand Down
Loading