Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4223c72
Pass trans_code to getrs in dpnp_solve()
vlad-perevezentsev Aug 18, 2025
80ce50c
Remove TODO
vlad-perevezentsev Sep 4, 2025
af0ab7d
Implement of dpnp.linalg.lu_solve for 2D inputs
vlad-perevezentsev Sep 4, 2025
17b11ae
Add dpnp.linalg.lu_solve to generated docs
vlad-perevezentsev Sep 4, 2025
b10a8d6
Add TestLuSolve to test_linalg.py
vlad-perevezentsev Sep 4, 2025
2021f77
Add sycl_queue and usm_type tests
vlad-perevezentsev Sep 4, 2025
be2725a
Update doc/comment lines
vlad-perevezentsev Sep 16, 2025
1e09cb7
Update dependency logic
vlad-perevezentsev Sep 18, 2025
9345b7b
Add trans code handling
vlad-perevezentsev Sep 18, 2025
687006f
Fix docs for lu:must be square
vlad-perevezentsev Sep 18, 2025
b1aed58
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 18, 2025
9aaff82
Update changelog
vlad-perevezentsev Sep 18, 2025
23ad15d
Apply docs remarks
vlad-perevezentsev Sep 19, 2025
82de136
Apply remarks
vlad-perevezentsev Sep 19, 2025
e586075
Add assert on USM data pointer to tests
vlad-perevezentsev Sep 19, 2025
7d1fd0b
Update data inputs for test_usm_type
vlad-perevezentsev Sep 19, 2025
87074fa
Add See Also to lu_factor
vlad-perevezentsev Sep 22, 2025
d81454e
Enable cupyx tests
vlad-perevezentsev Sep 22, 2025
4c9afa9
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 22, 2025
78a4c78
Adjust tolerance for test_lu_solve
vlad-perevezentsev Sep 22, 2025
27649a7
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 22, 2025
d0fbd49
Apply remark
vlad-perevezentsev Sep 22, 2025
52eac3d
Adjust tolerance for interger dtypes
vlad-perevezentsev Sep 22, 2025
ccbfccf
Solve race conditions issue for pivots
vlad-perevezentsev Sep 23, 2025
e91c24a
Revert the tolerance adjustment
vlad-perevezentsev Sep 23, 2025
0a9d817
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 23, 2025
dfc76e0
Apply remark
vlad-perevezentsev Sep 24, 2025
3127190
Adjust tolerance for interger dtypes
vlad-perevezentsev Sep 24, 2025
05907b1
Enable test_broadcast_rhs
vlad-perevezentsev Sep 24, 2025
56fea1e
Make test more stable by adjusting tolerance
vlad-perevezentsev Sep 24, 2025
5dda424
Merge master into impl_lu_solve_2D
vlad-perevezentsev Sep 24, 2025
ca225d8
Merge branch 'master' into impl_lu_solve_2D
antonwolfy Sep 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True):

def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
"""
Solve an equation system, a x = b, given the LU factorization of `a`.
Solve a linear system, :math:`a x = b`, given the LU factorization of `a`.

For full documentation refer to :obj:`scipy.linalg.lu_solve`.

Expand All @@ -983,13 +983,15 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
trans : {0, 1, 2} , optional
Type of system to solve:

===== =========
===== =================
trans system
===== =========
0 a x = b
1 a^T x = b
2 a^H x = b
===== =========
===== =================
0 :math:`a x = b`
1 :math:`a^T x = b`
2 :math:`a^H x = b`
===== =================

Default: ``0``.
overwrite_b : {None, bool}, optional
Whether to overwrite data in `b` (may increase performance).

Expand All @@ -1011,6 +1013,10 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
This function synchronizes in order to validate array elements
when ``check_finite=True``.

See Also
--------
:obj:`dpnp.linalg.lu_factor` : LU factorize a matrix.

Examples
--------
>>> import dpnp as np
Expand Down
6 changes: 4 additions & 2 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,12 +2508,14 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
if check_finite:
if not dpnp.isfinite(lu).all():
raise ValueError(
"array must not contain infs or NaNs.\n"
"LU factorization array must not contain infs or NaNs.\n"
"Note that when a singular matrix is given, unlike "
"dpnp.linalg.lu_factor returns an array containing NaN."
)
if not dpnp.isfinite(b).all():
raise ValueError("array must not contain infs or NaNs")
raise ValueError(
"Right-hand side array must not contain infs or NaNs"
)

lu_usm_arr = dpnp.get_usm_ndarray(lu)
piv_usm_arr = dpnp.get_usm_ndarray(piv)
Expand Down
6 changes: 6 additions & 0 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,6 +1931,7 @@ def test_overwrite_inplace(self, dtype):
)

assert lu is a_dp
assert lu.data.ptr == a_dp.data.ptr
assert lu.flags["F_CONTIGUOUS"] is True

L, U = self._split_lu(lu, 2, 2)
Expand All @@ -1948,6 +1949,7 @@ def test_overwrite_copy(self, dtype):
)

assert lu is not a_dp
assert lu.data.ptr != a_dp.data.ptr
assert lu.flags["F_CONTIGUOUS"] is True

L, U = self._split_lu(lu, 2, 2)
Expand All @@ -1974,6 +1976,7 @@ def test_overwrite_copy_special(self):
)

assert lu is not a_dp
assert lu.data.ptr != a_dp.data.ptr
assert lu.flags["F_CONTIGUOUS"] is True

L, U = self._split_lu(lu, 2, 2)
Expand Down Expand Up @@ -2217,6 +2220,7 @@ def test_overwrite_inplace(self, dtype):
)

assert x is b_dp
assert x.data.ptr == b_dp.data.ptr
assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
Expand All @@ -2230,6 +2234,7 @@ def test_overwrite_copy_special(self, dtype):
(lu, piv), b1, overwrite_b=True, check_finite=False
)
assert x1 is not b1
assert x1.data.ptr != b1.data.ptr

# F-contig, match dtype but read-only input
b2 = dpnp.array([1, 0], dtype=dtype, order="F")
Expand All @@ -2238,6 +2243,7 @@ def test_overwrite_copy_special(self, dtype):
(lu, piv), b2, overwrite_b=True, check_finite=False
)
assert x2 is not b2
assert x2.data.ptr != b2.data.ptr

for x in (x1, x2):
assert dpnp.allclose(
Expand Down
6 changes: 3 additions & 3 deletions dpnp/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,13 +1612,13 @@ def test_lu_factor(self, data, device):
assert_sycl_queue_equal(param_queue, a.sycl_queue)

@pytest.mark.parametrize(
"data",
"b_data",
[[1.0, 2.0], numpy.empty((2, 0))],
)
def test_lu_solve(self, data, device):
def test_lu_solve(self, b_data, device):
a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
lu, piv = dpnp.linalg.lu_factor(a)
b = dpnp.array(data, device=device)
b = dpnp.array(b_data, device=device)

result = dpnp.linalg.lu_solve((lu, piv), b)

Expand Down
8 changes: 4 additions & 4 deletions dpnp/tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,13 +1489,13 @@ def test_lu_factor(self, data, usm_type):

@pytest.mark.parametrize("usm_type_rhs", list_of_usm_types)
@pytest.mark.parametrize(
"data",
"b_data",
[[1.0, 2.0], numpy.empty((2, 0))],
)
def test_lu_solve(self, data, usm_type, usm_type_rhs):
a = dpnp.array(data, usm_type=usm_type)
def test_lu_solve(self, b_data, usm_type, usm_type_rhs):
a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type)
lu, piv = dpnp.linalg.lu_factor(a)
b = dpnp.array(data, usm_type=usm_type_rhs)
b = dpnp.array(b_data, usm_type=usm_type_rhs)

result = dpnp.linalg.lu_solve((lu, piv), b)

Expand Down
Loading