Skip to content

Commit 6541bf0

Browse files
Add TestLuSolve to test_linalg.py
1 parent a979b1a commit 6541bf0

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed

dpnp/tests/test_linalg.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,6 +2135,215 @@ def test_check_finite_raises(self):
21352135
assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True)
21362136

21372137

2138+
class TestLuSolve:
2139+
@staticmethod
2140+
def _make_nonsingular_np(shape, dtype, order):
2141+
A = generate_random_numpy_array(shape, dtype, order)
2142+
m, n = shape
2143+
k = min(m, n)
2144+
for i in range(k):
2145+
off = numpy.sum(numpy.abs(A[i, :n])) - numpy.abs(A[i, i])
2146+
A[i, i] = A.dtype.type(off + 1.0)
2147+
return A
2148+
2149+
@pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)])
2150+
@pytest.mark.parametrize("rhs_cols", [None, 1, 3])
2151+
@pytest.mark.parametrize("order", ["C", "F"])
2152+
@pytest.mark.parametrize(
2153+
"dtype", get_all_dtypes(no_bool=True, no_none=True)
2154+
)
2155+
def test_lu_solve(self, shape, rhs_cols, order, dtype):
2156+
a_np = self._make_nonsingular_np(shape, dtype, order)
2157+
a_dp = dpnp.array(a_np, order=order)
2158+
2159+
n = shape[0]
2160+
if rhs_cols is None:
2161+
b_np = generate_random_numpy_array((n,), dtype, order)
2162+
else:
2163+
b_np = generate_random_numpy_array((n, rhs_cols), dtype, order)
2164+
b_dp = dpnp.array(b_np, order=order)
2165+
2166+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2167+
x = dpnp.linalg.lu_solve(
2168+
(lu, piv), b_dp, trans=0, overwrite_b=False, check_finite=False
2169+
)
2170+
2171+
# check A @ x = b
2172+
Ax = a_dp @ x
2173+
assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6)
2174+
2175+
@pytest.mark.parametrize("trans", [0, 1, 2])
2176+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2177+
def test_trans(self, trans, dtype):
2178+
n = 4
2179+
a_np = self._make_nonsingular_np((n, n), dtype, order="F")
2180+
a_dp = dpnp.array(a_np, order="F")
2181+
b_dp = dpnp.array(generate_random_numpy_array((n, 2), dtype, "F"))
2182+
2183+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2184+
x = dpnp.linalg.lu_solve(
2185+
(lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False
2186+
)
2187+
2188+
if trans == 0:
2189+
lhs = a_dp @ x
2190+
elif trans == 1:
2191+
lhs = a_dp.T @ x
2192+
else: # trans == 2
2193+
lhs = a_dp.conj().T @ x
2194+
2195+
assert dpnp.allclose(lhs, b_dp, rtol=1e-6, atol=1e-6)
2196+
2197+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2198+
def test_overwrite_inplace(self, dtype):
2199+
a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
2200+
b_dp = dpnp.array([1, 0], dtype=dtype, order="F")
2201+
b_orig = b_dp.copy()
2202+
2203+
lu, piv = dpnp.linalg.lu_factor(
2204+
a_dp, overwrite_a=False, check_finite=False
2205+
)
2206+
x = dpnp.linalg.lu_solve(
2207+
(lu, piv), b_dp, trans=0, overwrite_b=True, check_finite=False
2208+
)
2209+
2210+
assert x is b_dp
2211+
assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6)
2212+
2213+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2214+
def test_overwrite_copy_special(self, dtype):
2215+
a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
2216+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2217+
2218+
# F-contig but dtype != res_type
2219+
b1 = dpnp.array([1, 0], dtype=dpnp.int32, order="F")
2220+
x1 = dpnp.linalg.lu_solve(
2221+
(lu, piv), b1, overwrite_b=True, check_finite=False
2222+
)
2223+
assert x1 is not b1
2224+
2225+
# F-contig, match dtype but read-only input
2226+
b2 = dpnp.array([1, 0], dtype=dtype, order="F")
2227+
b2.flags["WRITABLE"] = False
2228+
x2 = dpnp.linalg.lu_solve(
2229+
(lu, piv), b2, overwrite_b=True, check_finite=False
2230+
)
2231+
assert x2 is not b2
2232+
2233+
for x in (x1, x2):
2234+
assert dpnp.allclose(
2235+
a_dp @ x,
2236+
dpnp.array([1, 0], dtype=x.dtype),
2237+
rtol=1e-6,
2238+
atol=1e-6,
2239+
)
2240+
2241+
@pytest.mark.parametrize(
2242+
"dtype_a", get_all_dtypes(no_bool=True, no_none=True)
2243+
)
2244+
@pytest.mark.parametrize(
2245+
"dtype_b", get_all_dtypes(no_bool=True, no_none=True)
2246+
)
2247+
def test_diff_type(self, dtype_a, dtype_b):
2248+
a_np = self._make_nonsingular_np((3, 3), dtype_a, order="F")
2249+
a_dp = dpnp.array(a_np, order="F")
2250+
2251+
b_np = generate_random_numpy_array((3,), dtype_b, order="F")
2252+
b_dp = dpnp.array(b_np, order="F")
2253+
2254+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2255+
x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
2256+
assert dpnp.allclose(
2257+
a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-6, atol=1e-6
2258+
)
2259+
2260+
def test_strided_rhs(self):
2261+
n = 7
2262+
a_np = self._make_nonsingular_np(
2263+
(n, n), dpnp.default_float_type(), order="F"
2264+
)
2265+
a_dp = dpnp.array(a_np, order="F")
2266+
2267+
rhs_full = (
2268+
dpnp.arange(n * n, dtype=dpnp.default_float_type()).reshape(
2269+
n, n, order="F"
2270+
)
2271+
+ 1.0
2272+
)
2273+
b_dp = rhs_full[:, ::2][:, :3]
2274+
2275+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2276+
x = dpnp.linalg.lu_solve(
2277+
(lu, piv), b_dp, overwrite_b=False, check_finite=False
2278+
)
2279+
2280+
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6)
2281+
2282+
@pytest.mark.skip("Not implemented yet")
2283+
@pytest.mark.parametrize(
2284+
"b_shape",
2285+
[
2286+
(4,),
2287+
(4, 1),
2288+
(4, 3),
2289+
# (1, 4, 3),
2290+
# (2, 4, 3),
2291+
# (1, 1, 4, 3)
2292+
],
2293+
)
2294+
def test_broadcast_rhs(self, b_shape):
2295+
dtype = dpnp.default_float_type()
2296+
2297+
a_np = self._make_nonsingular_np((4, 4), dtype, order="F")
2298+
a_dp = dpnp.array(a_np, order="F")
2299+
2300+
b_np = generate_random_numpy_array(b_shape, dtype, order="F")
2301+
b_dp = dpnp.array(b_np, order="F")
2302+
2303+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2304+
x = dpnp.linalg.lu_solve(
2305+
(lu, piv), b_dp, overwrite_b=True, check_finite=False
2306+
)
2307+
2308+
assert x.shape == b_dp.shape
2309+
2310+
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6)
2311+
2312+
@pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)])
2313+
@pytest.mark.parametrize("rhs_cols", [None, 0, 3])
2314+
def test_empty_shapes(self, shape, rhs_cols):
2315+
a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
2316+
if min(shape) > 0:
2317+
for i in range(min(shape)):
2318+
a_dp[i, i] = a_dp.dtype.type(1.0)
2319+
2320+
n = shape[0]
2321+
if rhs_cols is None:
2322+
b_shape = (n,)
2323+
else:
2324+
b_shape = (n, rhs_cols)
2325+
b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
2326+
2327+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2328+
x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
2329+
2330+
assert x.shape == b_shape
2331+
2332+
@pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
2333+
def test_check_finite_raises(self, bad):
2334+
a_dp = dpnp.array([[1.0, 0.0], [0.0, 1.0]], order="F")
2335+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2336+
2337+
b_bad = dpnp.array([1.0, bad], order="F")
2338+
assert_raises(
2339+
ValueError,
2340+
dpnp.linalg.lu_solve,
2341+
(lu, piv),
2342+
b_bad,
2343+
check_finite=True,
2344+
)
2345+
2346+
21382347
class TestMatrixPower:
21392348
@pytest.mark.parametrize("dtype", get_all_dtypes())
21402349
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)