Skip to content

Commit bc50cbb

Browse files
Add tests for batched lu_factor
1 parent bde2d4e commit bc50cbb

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed

dpnp/tests/test_linalg.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,6 +2022,119 @@ def test_batched_not_supported(self):
20222022
assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp)
20232023

20242024

2025+
class TestLuFactorBatched:
2026+
@staticmethod
2027+
def _apply_pivots_rows(A_dp, piv_dp):
2028+
m = A_dp.shape[0]
2029+
rows = dpnp.arange(m)
2030+
for i in range(int(piv_dp.shape[0])):
2031+
r = int(piv_dp[i].item())
2032+
if i != r:
2033+
tmp = rows[i].copy()
2034+
rows[i] = rows[r]
2035+
rows[r] = tmp
2036+
return A_dp[rows]
2037+
2038+
@staticmethod
2039+
def _split_lu(lu, m, n):
2040+
L = dpnp.tril(lu, k=-1)
2041+
dpnp.fill_diagonal(L, 1)
2042+
L = L[:, : min(m, n)]
2043+
U = dpnp.triu(lu)[: min(m, n), :]
2044+
return L, U
2045+
2046+
@pytest.mark.parametrize(
2047+
"shape",
2048+
[(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)],
2049+
ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"],
2050+
)
2051+
@pytest.mark.parametrize("order", ["C", "F"])
2052+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
2053+
def test_lu_factor_batched(self, shape, order, dtype):
2054+
a_np = generate_random_numpy_array(shape, dtype, order)
2055+
a_dp = dpnp.array(a_np, order=order)
2056+
2057+
lu, piv = dpnp.linalg.lu_factor(
2058+
a_dp, check_finite=False, overwrite_a=False
2059+
)
2060+
2061+
assert lu.shape == a_dp.shape
2062+
m, n = shape[-2], shape[-1]
2063+
assert piv.shape == (*shape[:-2], min(m, n))
2064+
assert piv.dtype == dpnp.int64
2065+
2066+
a_3d = a_dp.reshape((-1, m, n))
2067+
lu_3d = lu.reshape((-1, m, n))
2068+
piv_2d = piv.reshape((-1, min(m, n)))
2069+
for i in range(a_3d.shape[0]):
2070+
L, U = self._split_lu(lu_3d[i], m, n)
2071+
A_cast = a_3d[i].astype(L.dtype, copy=False)
2072+
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2073+
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2074+
2075+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2076+
@pytest.mark.parametrize("order", ["C", "F"])
2077+
def test_overwrite(self, dtype, order):
2078+
a_dp = dpnp.arange(2 * 2 * 3, dtype=dtype).reshape(3, 2, 2, order=order)
2079+
a_dp_orig = a_dp.copy()
2080+
lu, piv = dpnp.linalg.lu_factor(
2081+
a_dp, overwrite_a=True, check_finite=False
2082+
)
2083+
2084+
assert lu is not a_dp
2085+
assert_allclose(a_dp, a_dp_orig)
2086+
2087+
m = n = 2
2088+
lu_3d = lu.reshape((-1, m, n))
2089+
a_3d = a_dp.reshape((-1, m, n))
2090+
piv_2d = piv.reshape((-1, min(m, n)))
2091+
for i in range(a_3d.shape[0]):
2092+
L, U = self._split_lu(lu_3d[i], m, n)
2093+
A_cast = a_3d[i].astype(L.dtype, copy=False)
2094+
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2095+
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2096+
2097+
@pytest.mark.parametrize(
2098+
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
2099+
)
2100+
def test_empty_inputs(self, shape):
2101+
a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
2102+
2103+
lu, piv = dpnp.linalg.lu_factor(a, check_finite=False)
2104+
assert lu.shape == shape
2105+
m, n = shape[-2:]
2106+
assert piv.shape == (*shape[:-2], min(m, n))
2107+
2108+
def test_strided(self):
2109+
a = (
2110+
dpnp.arange(5 * 3 * 3, dtype=dpnp.default_float_type()).reshape(
2111+
5, 3, 3, order="F"
2112+
)
2113+
+ 0.1
2114+
)
2115+
a_stride = a[::2]
2116+
lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False)
2117+
for i in range(a_stride.shape[0]):
2118+
L, U = self._split_lu(lu[i], 3, 3)
2119+
PA = self._apply_pivots_rows(
2120+
a_stride[i].astype(L.dtype, copy=False), piv[i]
2121+
)
2122+
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2123+
2124+
def test_singular_matrix(self):
2125+
a = dpnp.zeros((3, 2, 2), dtype=dpnp.float64)
2126+
a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
2127+
a[1] = dpnp.eye(2)
2128+
a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]])
2129+
with pytest.warns(RuntimeWarning, match="Singular matrix"):
2130+
dpnp.linalg.lu_factor(a, check_finite=False)
2131+
2132+
def test_check_finite_raises(self):
2133+
a = dpnp.ones((2, 3, 3), dtype=dpnp.float64, order="F")
2134+
a[1, 0, 0] = dpnp.nan
2135+
assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True)
2136+
2137+
20252138
class TestMatrixPower:
20262139
@pytest.mark.parametrize("dtype", get_all_dtypes())
20272140
@pytest.mark.parametrize(

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ def test_lstsq(self, m, n, nrhs, device):
15721572

15731573
@pytest.mark.parametrize(
15741574
"data",
1575-
[[[1.0, 2.0], [3.0, 5.0]], [[]]],
1575+
[[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]],
15761576
)
15771577
def test_lu_factor(self, data, device):
15781578
a = dpnp.array(data, device=device)

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other):
14511451

14521452
@pytest.mark.parametrize(
14531453
"data",
1454-
[[[1.0, 2.0], [3.0, 5.0]], [[]]],
1454+
[[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]],
14551455
)
14561456
def test_lu_factor(self, data, usm_type):
14571457
a = dpnp.array(data, usm_type=usm_type)

0 commit comments

Comments
 (0)