Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Decompositions
dpnp.linalg.cholesky
dpnp.linalg.outer
dpnp.linalg.qr
dpnp.linalg.lu_factor
dpnp.linalg.svd
dpnp.linalg.svdvals

Expand Down
22 changes: 14 additions & 8 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,27 +914,33 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
where `P` is a permutation matrix, `L` is lower triangular with unit
diagonal elements, and `U` is upper triangular.

For full documentation refer to :obj:`scipy.linalg.lu_factor`.

Parameters
----------
a : (M, N) {dpnp.ndarray, usm_ndarray}
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
Input array to decompose.
overwrite_a : {None, bool}, optional
Whether to overwrite data in `a` (may increase performance)
Whether to overwrite data in `a` (may increase performance).

Default: ``False``.
check_finite : {None, bool}, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Default: ``True``.

Returns
-------
lu :(M, N) dpnp.ndarray
Matrix containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv (K, ): dpnp.ndarray
Pivot indices representing the permutation matrix P:
lu : (..., M, N) dpnp.ndarray
Matrix containing `U` in its upper triangle,
and `L` in its lower triangle.
The unit diagonal elements of `L` are not stored.
piv : (..., K) dpnp.ndarray
Pivot indices representing the permutation matrix `P`:
row i of matrix was interchanged with row piv[i].
``K = min(M, N)``.
Where ``K = min(M, N)``.

Warning
-------
Expand Down
12 changes: 6 additions & 6 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
if any(dev_info_h):
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. "
f"Diagonal numbers {diag_nums} are exactly zero. "
"Singular matrix.",
RuntimeWarning,
stacklevel=2,
Expand Down Expand Up @@ -2463,17 +2463,17 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
# - not writeable
if not overwrite_a or _is_copy_required(a, res_type):
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)
_manager.add_event_pair(ht_ev, dep_ev)
else:
# input is suitable for in-place modification
a_h = a
copy_ev = None
dep_ev = _manager.submitted_events

m, n = a.shape

Expand All @@ -2493,14 +2493,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
depends=[copy_ev] if copy_ev is not None else [],
depends=dep_ev,
)
_manager.add_event_pair(ht_ev, getrf_ev)

if any(dev_info_h):
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
f"Diagonal number {diag_nums} is exactly zero. Singular matrix.",
RuntimeWarning,
stacklevel=2,
)
Expand Down
18 changes: 9 additions & 9 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,7 +1911,7 @@ def test_lu_factor(self, shape, order, dtype):
A_cast = a_dp.astype(LU.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv)

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_overwrite_inplace(self, dtype):
Expand All @@ -1928,7 +1928,7 @@ def test_overwrite_inplace(self, dtype):
PA = self._apply_pivots_rows(a_dp_orig, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_overwrite_copy(self, dtype):
Expand All @@ -1945,7 +1945,7 @@ def test_overwrite_copy(self, dtype):
PA = self._apply_pivots_rows(a_dp_orig, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

def test_overwrite_copy_special(self):
# F-contig but dtype != res_type
Expand All @@ -1972,7 +1972,7 @@ def test_overwrite_copy_special(self):
a_orig.astype(L.dtype, copy=False), piv
)
LU = L @ U
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)])
def test_empty_inputs(self, shape):
Expand Down Expand Up @@ -2003,7 +2003,7 @@ def test_strided(self, sl):
PA = self._apply_pivots_rows(a_dp, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

def test_singular_matrix(self):
a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
Expand Down Expand Up @@ -2070,7 +2070,7 @@ def test_lu_factor_batched(self, shape, order, dtype):
L, U = self._split_lu(lu_3d[i], m, n)
A_cast = a_3d[i].astype(L.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize("order", ["C", "F"])
Expand All @@ -2082,7 +2082,7 @@ def test_overwrite(self, dtype, order):
)

assert lu is not a_dp
assert_allclose(a_dp, a_dp_orig)
assert dpnp.allclose(a_dp, a_dp_orig)

m = n = 2
lu_3d = lu.reshape((-1, m, n))
Expand All @@ -2092,7 +2092,7 @@ def test_overwrite(self, dtype, order):
L, U = self._split_lu(lu_3d[i], m, n)
A_cast = a_3d[i].astype(L.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize(
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
Expand All @@ -2119,7 +2119,7 @@ def test_strided(self):
PA = self._apply_pivots_rows(
a_stride[i].astype(L.dtype, copy=False), piv[i]
)
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

def test_singular_matrix(self):
a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type())
Expand Down
Loading