Skip to content

Commit 506d24d

Browse files
Add TestLuFactor
1 parent e781beb commit 506d24d

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

dpnp/tests/test_linalg.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
import dpnp
18+
import dpnp.linalg
1819

1920
from .helper import (
2021
assert_dtype_allclose,
@@ -1868,6 +1869,138 @@ def test_lstsq_errors(self):
18681869
assert_raises(TypeError, dpnp.linalg.lstsq, a_dp, b_dp, [-1])
18691870

18701871

1872+
class TestLuFactor:
1873+
@staticmethod
1874+
def _apply_pivots_rows(A_dp, piv_dp):
1875+
m = A_dp.shape[0]
1876+
rows = dpnp.arange(m)
1877+
for i in range(int(piv_dp.shape[0])):
1878+
r = int(piv_dp[i].item())
1879+
if i != r:
1880+
tmp = rows[i].copy()
1881+
rows[i] = rows[r]
1882+
rows[r] = tmp
1883+
return A_dp[rows]
1884+
1885+
@staticmethod
1886+
def _split_lu(lu, m, n):
1887+
L = dpnp.tril(lu, k=-1)
1888+
dpnp.fill_diagonal(L, 1)
1889+
L = L[:, : min(m, n)]
1890+
U = dpnp.triu(lu)[: min(m, n), :]
1891+
return L, U
1892+
1893+
@pytest.mark.parametrize(
1894+
"shape", [(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)]
1895+
)
1896+
@pytest.mark.parametrize("order", ["C", "F"])
1897+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
1898+
def test_lu_factor(self, shape, order, dtype):
1899+
a_np = generate_random_numpy_array(shape, dtype, order)
1900+
a_dp = dpnp.array(a_np, order=order)
1901+
1902+
lu, piv = dpnp.linalg.lu_factor(
1903+
a_dp, check_finite=False, overwrite_a=False
1904+
)
1905+
1906+
# verify piv
1907+
assert piv.shape == (min(shape),)
1908+
assert piv.dtype == dpnp.int64
1909+
if shape[0] > 0:
1910+
assert int(dpnp.min(piv)) >= 0
1911+
assert int(dpnp.max(piv)) < shape[0]
1912+
1913+
m, n = shape
1914+
L, U = self._split_lu(lu, m, n)
1915+
LU = L @ U
1916+
1917+
A_cast = a_dp.astype(LU.dtype, copy=False)
1918+
PA = self._apply_pivots_rows(A_cast, piv)
1919+
1920+
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1921+
1922+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
1923+
def test_overwrite_inplace(self, dtype):
1924+
a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
1925+
a_dp_orig = a_dp.copy()
1926+
lu, piv = dpnp.linalg.lu_factor(
1927+
a_dp, overwrite_a=True, check_finite=False
1928+
)
1929+
1930+
assert lu is a_dp
1931+
assert lu.flags["F_CONTIGUOUS"] is True
1932+
1933+
L, U = self._split_lu(lu, 2, 2)
1934+
PA = self._apply_pivots_rows(a_dp_orig, piv)
1935+
LU = L @ U
1936+
1937+
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1938+
1939+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
1940+
def test_overwrite_copy(self, dtype):
1941+
a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="C")
1942+
a_dp_orig = a_dp.copy()
1943+
lu, piv = dpnp.linalg.lu_factor(
1944+
a_dp, overwrite_a=True, check_finite=False
1945+
)
1946+
1947+
assert lu is not a_dp
1948+
assert lu.flags["F_CONTIGUOUS"] is True
1949+
1950+
L, U = self._split_lu(lu, 2, 2)
1951+
PA = self._apply_pivots_rows(a_dp_orig, piv)
1952+
LU = L @ U
1953+
1954+
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1955+
1956+
@pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)])
1957+
def test_empty_inputs(self, shape):
1958+
a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
1959+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
1960+
assert lu.shape == shape
1961+
assert piv.shape == (min(shape),)
1962+
1963+
@pytest.mark.parametrize(
1964+
"sl",
1965+
[
1966+
(slice(None, None, 2), slice(None, None, 2)),
1967+
(slice(None, None, -1), slice(None, None, -1)),
1968+
],
1969+
)
1970+
def test_strided(self, sl):
1971+
base = (
1972+
numpy.arange(7 * 7, dtype=dpnp.default_float_type()).reshape(
1973+
7, 7, order="F"
1974+
)
1975+
+ 0.1
1976+
)
1977+
a_np = base[sl]
1978+
a_dp = dpnp.array(a_np)
1979+
1980+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
1981+
L, U = self._split_lu(lu, *a_dp.shape)
1982+
PA = self._apply_pivots_rows(a_dp, piv)
1983+
LU = L @ U
1984+
1985+
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1986+
1987+
def test_singular_matrix(self):
1988+
a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
1989+
with pytest.warns(RuntimeWarning, match="Singular matrix"):
1990+
dpnp.linalg.lu_factor(a_dp, check_finite=False)
1991+
1992+
@pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
1993+
def test_check_finite_raises(self, bad):
1994+
a_dp = dpnp.array([[1.0, 2.0], [3.0, bad]], order="F")
1995+
assert_raises(
1996+
ValueError, dpnp.linalg.lu_factor, a_dp, check_finite=True
1997+
)
1998+
1999+
def test_batched_not_supported(self):
2000+
a_dp = dpnp.ones((2, 2, 2))
2001+
assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp)
2002+
2003+
18712004
class TestMatrixPower:
18722005
@pytest.mark.parametrize("dtype", get_all_dtypes())
18732006
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)