Skip to content

Commit d6477c7

Browse files
Add tests for batched lu_factor
1 parent 97b6202 commit d6477c7

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed

dpnp/tests/test_linalg.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,6 +2018,119 @@ def test_batched_not_supported(self):
20182018
assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp)
20192019

20202020

2021+
class TestLuFactorBatched:
2022+
@staticmethod
2023+
def _apply_pivots_rows(A_dp, piv_dp):
2024+
m = A_dp.shape[0]
2025+
rows = dpnp.arange(m)
2026+
for i in range(int(piv_dp.shape[0])):
2027+
r = int(piv_dp[i].item())
2028+
if i != r:
2029+
tmp = rows[i].copy()
2030+
rows[i] = rows[r]
2031+
rows[r] = tmp
2032+
return A_dp[rows]
2033+
2034+
@staticmethod
2035+
def _split_lu(lu, m, n):
2036+
L = dpnp.tril(lu, k=-1)
2037+
dpnp.fill_diagonal(L, 1)
2038+
L = L[:, : min(m, n)]
2039+
U = dpnp.triu(lu)[: min(m, n), :]
2040+
return L, U
2041+
2042+
@pytest.mark.parametrize(
2043+
"shape",
2044+
[(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)],
2045+
ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"],
2046+
)
2047+
@pytest.mark.parametrize("order", ["C", "F"])
2048+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
2049+
def test_lu_factor_batched(self, shape, order, dtype):
2050+
a_np = generate_random_numpy_array(shape, dtype, order)
2051+
a_dp = dpnp.array(a_np, order=order)
2052+
2053+
lu, piv = dpnp.linalg.lu_factor(
2054+
a_dp, check_finite=False, overwrite_a=False
2055+
)
2056+
2057+
assert lu.shape == a_dp.shape
2058+
m, n = shape[-2], shape[-1]
2059+
assert piv.shape == (*shape[:-2], min(m, n))
2060+
assert piv.dtype == dpnp.int64
2061+
2062+
a_3d = a_dp.reshape((-1, m, n))
2063+
lu_3d = lu.reshape((-1, m, n))
2064+
piv_2d = piv.reshape((-1, min(m, n)))
2065+
for i in range(a_3d.shape[0]):
2066+
L, U = self._split_lu(lu_3d[i], m, n)
2067+
A_cast = a_3d[i].astype(L.dtype, copy=False)
2068+
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2069+
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2070+
2071+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2072+
@pytest.mark.parametrize("order", ["C", "F"])
2073+
def test_overwrite(self, dtype, order):
2074+
a_dp = dpnp.arange(2 * 2 * 3, dtype=dtype).reshape(3, 2, 2, order=order)
2075+
a_dp_orig = a_dp.copy()
2076+
lu, piv = dpnp.linalg.lu_factor(
2077+
a_dp, overwrite_a=True, check_finite=False
2078+
)
2079+
2080+
assert lu is not a_dp
2081+
assert_allclose(a_dp, a_dp_orig)
2082+
2083+
m = n = 2
2084+
lu_3d = lu.reshape((-1, m, n))
2085+
a_3d = a_dp.reshape((-1, m, n))
2086+
piv_2d = piv.reshape((-1, min(m, n)))
2087+
for i in range(a_3d.shape[0]):
2088+
L, U = self._split_lu(lu_3d[i], m, n)
2089+
A_cast = a_3d[i].astype(L.dtype, copy=False)
2090+
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2091+
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2092+
2093+
@pytest.mark.parametrize(
2094+
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
2095+
)
2096+
def test_empty_inputs(self, shape):
2097+
a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
2098+
2099+
lu, piv = dpnp.linalg.lu_factor(a, check_finite=False)
2100+
assert lu.shape == shape
2101+
m, n = shape[-2:]
2102+
assert piv.shape == (*shape[:-2], min(m, n))
2103+
2104+
def test_strided(self):
2105+
a = (
2106+
dpnp.arange(5 * 3 * 3, dtype=dpnp.default_float_type()).reshape(
2107+
5, 3, 3, order="F"
2108+
)
2109+
+ 0.1
2110+
)
2111+
a_stride = a[::2]
2112+
lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False)
2113+
for i in range(a_stride.shape[0]):
2114+
L, U = self._split_lu(lu[i], 3, 3)
2115+
PA = self._apply_pivots_rows(
2116+
a_stride[i].astype(L.dtype, copy=False), piv[i]
2117+
)
2118+
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2119+
2120+
def test_singular_matrix(self):
2121+
a = dpnp.zeros((3, 2, 2), dtype=dpnp.float64)
2122+
a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
2123+
a[1] = dpnp.eye(2)
2124+
a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]])
2125+
with pytest.warns(RuntimeWarning, match="Singular matrix"):
2126+
dpnp.linalg.lu_factor(a, check_finite=False)
2127+
2128+
def test_check_finite_raises(self):
2129+
a = dpnp.ones((2, 3, 3), dtype=dpnp.float64, order="F")
2130+
a[1, 0, 0] = dpnp.nan
2131+
assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True)
2132+
2133+
20212134
class TestMatrixPower:
20222135
@pytest.mark.parametrize("dtype", get_all_dtypes())
20232136
@pytest.mark.parametrize(

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ def test_lstsq(self, m, n, nrhs, device):
15721572

15731573
@pytest.mark.parametrize(
15741574
"data",
1575-
[[[1.0, 2.0], [3.0, 5.0]], [[]]],
1575+
[[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]],
15761576
)
15771577
def test_lu_factor(self, data, device):
15781578
a = dpnp.array(data, device=device)

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other):
14511451

14521452
@pytest.mark.parametrize(
14531453
"data",
1454-
[[[1.0, 2.0], [3.0, 5.0]], [[]]],
1454+
[[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]],
14551455
)
14561456
def test_lu_factor(self, data, usm_type):
14571457
a = dpnp.array(data, usm_type=usm_type)

0 commit comments

Comments
 (0)