|
15 | 15 | ) |
16 | 16 |
|
17 | 17 | import dpnp |
| 18 | +import dpnp.linalg |
18 | 19 |
|
19 | 20 | from .helper import ( |
20 | 21 | assert_dtype_allclose, |
@@ -1868,6 +1869,138 @@ def test_lstsq_errors(self): |
1868 | 1869 | assert_raises(TypeError, dpnp.linalg.lstsq, a_dp, b_dp, [-1]) |
1869 | 1870 |
|
1870 | 1871 |
|
| 1872 | +class TestLuFactor: |
| 1873 | + @staticmethod |
| 1874 | + def _apply_pivots_rows(A_dp, piv_dp): |
| 1875 | + m = A_dp.shape[0] |
| 1876 | + rows = dpnp.arange(m) |
| 1877 | + for i in range(int(piv_dp.shape[0])): |
| 1878 | + r = int(piv_dp[i].item()) |
| 1879 | + if i != r: |
| 1880 | + tmp = rows[i].copy() |
| 1881 | + rows[i] = rows[r] |
| 1882 | + rows[r] = tmp |
| 1883 | + return A_dp[rows] |
| 1884 | + |
| 1885 | + @staticmethod |
| 1886 | + def _split_lu(lu, m, n): |
| 1887 | + L = dpnp.tril(lu, k=-1) |
| 1888 | + dpnp.fill_diagonal(L, 1) |
| 1889 | + L = L[:, : min(m, n)] |
| 1890 | + U = dpnp.triu(lu)[: min(m, n), :] |
| 1891 | + return L, U |
| 1892 | + |
| 1893 | + @pytest.mark.parametrize( |
| 1894 | + "shape", [(1, 1), (2, 2), (3, 3), (1, 5), (5, 1), (2, 5), (5, 2)] |
| 1895 | + ) |
| 1896 | + @pytest.mark.parametrize("order", ["C", "F"]) |
| 1897 | + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) |
| 1898 | + def test_lu_factor(self, shape, order, dtype): |
| 1899 | + a_np = generate_random_numpy_array(shape, dtype, order) |
| 1900 | + a_dp = dpnp.array(a_np, order=order) |
| 1901 | + |
| 1902 | + lu, piv = dpnp.linalg.lu_factor( |
| 1903 | + a_dp, check_finite=False, overwrite_a=False |
| 1904 | + ) |
| 1905 | + |
| 1906 | + # verify piv |
| 1907 | + assert piv.shape == (min(shape),) |
| 1908 | + assert piv.dtype == dpnp.int64 |
| 1909 | + if shape[0] > 0: |
| 1910 | + assert int(dpnp.min(piv)) >= 0 |
| 1911 | + assert int(dpnp.max(piv)) < shape[0] |
| 1912 | + |
| 1913 | + m, n = shape |
| 1914 | + L, U = self._split_lu(lu, m, n) |
| 1915 | + LU = L @ U |
| 1916 | + |
| 1917 | + A_cast = a_dp.astype(LU.dtype, copy=False) |
| 1918 | + PA = self._apply_pivots_rows(A_cast, piv) |
| 1919 | + |
| 1920 | + assert_allclose(LU, PA, rtol=1e-6, atol=1e-6) |
| 1921 | + |
| 1922 | + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) |
| 1923 | + def test_overwrite_inplace(self, dtype): |
| 1924 | + a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F") |
| 1925 | + a_dp_orig = a_dp.copy() |
| 1926 | + lu, piv = dpnp.linalg.lu_factor( |
| 1927 | + a_dp, overwrite_a=True, check_finite=False |
| 1928 | + ) |
| 1929 | + |
| 1930 | + assert lu is a_dp |
| 1931 | + assert lu.flags["F_CONTIGUOUS"] is True |
| 1932 | + |
| 1933 | + L, U = self._split_lu(lu, 2, 2) |
| 1934 | + PA = self._apply_pivots_rows(a_dp_orig, piv) |
| 1935 | + LU = L @ U |
| 1936 | + |
| 1937 | + assert_allclose(LU, PA, rtol=1e-6, atol=1e-6) |
| 1938 | + |
| 1939 | + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) |
| 1940 | + def test_overwrite_copy(self, dtype): |
| 1941 | + a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="C") |
| 1942 | + a_dp_orig = a_dp.copy() |
| 1943 | + lu, piv = dpnp.linalg.lu_factor( |
| 1944 | + a_dp, overwrite_a=True, check_finite=False |
| 1945 | + ) |
| 1946 | + |
| 1947 | + assert lu is not a_dp |
| 1948 | + assert lu.flags["F_CONTIGUOUS"] is True |
| 1949 | + |
| 1950 | + L, U = self._split_lu(lu, 2, 2) |
| 1951 | + PA = self._apply_pivots_rows(a_dp_orig, piv) |
| 1952 | + LU = L @ U |
| 1953 | + |
| 1954 | + assert_allclose(LU, PA, rtol=1e-6, atol=1e-6) |
| 1955 | + |
| 1956 | + @pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)]) |
| 1957 | + def test_empty_inputs(self, shape): |
| 1958 | + a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") |
| 1959 | + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) |
| 1960 | + assert lu.shape == shape |
| 1961 | + assert piv.shape == (min(shape),) |
| 1962 | + |
| 1963 | + @pytest.mark.parametrize( |
| 1964 | + "sl", |
| 1965 | + [ |
| 1966 | + (slice(None, None, 2), slice(None, None, 2)), |
| 1967 | + (slice(None, None, -1), slice(None, None, -1)), |
| 1968 | + ], |
| 1969 | + ) |
| 1970 | + def test_strided(self, sl): |
| 1971 | + base = ( |
| 1972 | + numpy.arange(7 * 7, dtype=dpnp.default_float_type()).reshape( |
| 1973 | + 7, 7, order="F" |
| 1974 | + ) |
| 1975 | + + 0.1 |
| 1976 | + ) |
| 1977 | + a_np = base[sl] |
| 1978 | + a_dp = dpnp.array(a_np) |
| 1979 | + |
| 1980 | + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) |
| 1981 | + L, U = self._split_lu(lu, *a_dp.shape) |
| 1982 | + PA = self._apply_pivots_rows(a_dp, piv) |
| 1983 | + LU = L @ U |
| 1984 | + |
| 1985 | + assert_allclose(LU, PA, rtol=1e-6, atol=1e-6) |
| 1986 | + |
| 1987 | + def test_singular_matrix(self): |
| 1988 | + a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]]) |
| 1989 | + with pytest.warns(RuntimeWarning, match="Singular matrix"): |
| 1990 | + dpnp.linalg.lu_factor(a_dp, check_finite=False) |
| 1991 | + |
| 1992 | + @pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan]) |
| 1993 | + def test_check_finite_raises(self, bad): |
| 1994 | + a_dp = dpnp.array([[1.0, 2.0], [3.0, bad]], order="F") |
| 1995 | + assert_raises( |
| 1996 | + ValueError, dpnp.linalg.lu_factor, a_dp, check_finite=True |
| 1997 | + ) |
| 1998 | + |
| 1999 | + def test_batched_not_supported(self): |
| 2000 | + a_dp = dpnp.ones((2, 2, 2)) |
| 2001 | + assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp) |
| 2002 | + |
| 2003 | + |
1871 | 2004 | class TestMatrixPower: |
1872 | 2005 | @pytest.mark.parametrize("dtype", get_all_dtypes()) |
1873 | 2006 | @pytest.mark.parametrize( |
|
0 commit comments