Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Decompositions
dpnp.linalg.cholesky
dpnp.linalg.outer
dpnp.linalg.qr
dpnp.linalg.lu_factor
dpnp.linalg.svd
dpnp.linalg.svdvals

Expand Down
27 changes: 14 additions & 13 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,38 +914,39 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
where `P` is a permutation matrix, `L` is lower triangular with unit
diagonal elements, and `U` is upper triangular.

For full documentation refer to :obj:`scipy.linalg.lu_factor`.

Parameters
----------
a : (M, N) {dpnp.ndarray, usm_ndarray}
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
Input array to decompose.
overwrite_a : {None, bool}, optional
Whether to overwrite data in `a` (may increase performance)
Whether to overwrite data in `a` (may increase performance).

Default: ``False``.
check_finite : {None, bool}, optional
Whether to check that the input matrix contains only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Default: ``True``.

Returns
-------
lu :(M, N) dpnp.ndarray
Matrix containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv (K, ): dpnp.ndarray
Pivot indices representing the permutation matrix P:
lu : (..., M, N) dpnp.ndarray
Matrix containing `U` in its upper triangle,
and `L` in its lower triangle.
The unit diagonal elements of `L` are not stored.
piv : (..., K) dpnp.ndarray
Pivot indices representing the permutation matrix `P`:
row i of matrix was interchanged with row piv[i].
``K = min(M, N)``.
Where ``K = min(M, N)``.

Warning
-------
This function synchronizes in order to validate array elements
when ``check_finite=True``.

Limitations
-----------
Only two-dimensional input matrices are supported.
Otherwise, the function raises ``NotImplementedError`` exception.

Examples
--------
>>> import dpnp as np
Expand Down
13 changes: 7 additions & 6 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
if any(dev_info_h):
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. "
f"Diagonal numbers {diag_nums} are exactly zero. "
"Singular matrix.",
RuntimeWarning,
stacklevel=2,
Expand Down Expand Up @@ -2463,17 +2463,18 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
# - not writeable
if not overwrite_a or _is_copy_required(a, res_type):
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)
_manager.add_event_pair(ht_ev, dep_ev)
dep_ev = [dep_ev]
else:
# input is suitable for in-place modification
a_h = a
copy_ev = None
dep_ev = _manager.submitted_events

m, n = a.shape

Expand All @@ -2493,14 +2494,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
depends=[copy_ev] if copy_ev is not None else [],
depends=dep_ev,
)
_manager.add_event_pair(ht_ev, getrf_ev)

if any(dev_info_h):
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
warn(
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
f"Diagonal number {diag_nums} is exactly zero. Singular matrix.",
RuntimeWarning,
stacklevel=2,
)
Expand Down
18 changes: 9 additions & 9 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,7 +1911,7 @@ def test_lu_factor(self, shape, order, dtype):
A_cast = a_dp.astype(LU.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv)

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_overwrite_inplace(self, dtype):
Expand All @@ -1928,7 +1928,7 @@ def test_overwrite_inplace(self, dtype):
PA = self._apply_pivots_rows(a_dp_orig, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
def test_overwrite_copy(self, dtype):
Expand All @@ -1945,7 +1945,7 @@ def test_overwrite_copy(self, dtype):
PA = self._apply_pivots_rows(a_dp_orig, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

def test_overwrite_copy_special(self):
# F-contig but dtype != res_type
Expand All @@ -1972,7 +1972,7 @@ def test_overwrite_copy_special(self):
a_orig.astype(L.dtype, copy=False), piv
)
LU = L @ U
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)])
def test_empty_inputs(self, shape):
Expand Down Expand Up @@ -2003,7 +2003,7 @@ def test_strided(self, sl):
PA = self._apply_pivots_rows(a_dp, piv)
LU = L @ U

assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)

def test_singular_matrix(self):
a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
Expand Down Expand Up @@ -2070,7 +2070,7 @@ def test_lu_factor_batched(self, shape, order, dtype):
L, U = self._split_lu(lu_3d[i], m, n)
A_cast = a_3d[i].astype(L.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize("order", ["C", "F"])
Expand All @@ -2082,7 +2082,7 @@ def test_overwrite(self, dtype, order):
)

assert lu is not a_dp
assert_allclose(a_dp, a_dp_orig)
assert dpnp.allclose(a_dp, a_dp_orig)

m = n = 2
lu_3d = lu.reshape((-1, m, n))
Expand All @@ -2092,7 +2092,7 @@ def test_overwrite(self, dtype, order):
L, U = self._split_lu(lu_3d[i], m, n)
A_cast = a_3d[i].astype(L.dtype, copy=False)
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize(
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
Expand All @@ -2119,7 +2119,7 @@ def test_strided(self):
PA = self._apply_pivots_rows(
a_stride[i].astype(L.dtype, copy=False), piv[i]
)
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)

def test_singular_matrix(self):
a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type())
Expand Down
Loading