Skip to content

Commit b10a8d6

Browse files
Add TestLuSolve to test_linalg.py
1 parent 17b11ae commit b10a8d6

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed

dpnp/tests/test_linalg.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2144,6 +2144,215 @@ def test_check_finite_raises(self):
21442144
assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True)
21452145

21462146

2147+
class TestLuSolve:
2148+
@staticmethod
2149+
def _make_nonsingular_np(shape, dtype, order):
2150+
A = generate_random_numpy_array(shape, dtype, order)
2151+
m, n = shape
2152+
k = min(m, n)
2153+
for i in range(k):
2154+
off = numpy.sum(numpy.abs(A[i, :n])) - numpy.abs(A[i, i])
2155+
A[i, i] = A.dtype.type(off + 1.0)
2156+
return A
2157+
2158+
@pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)])
2159+
@pytest.mark.parametrize("rhs_cols", [None, 1, 3])
2160+
@pytest.mark.parametrize("order", ["C", "F"])
2161+
@pytest.mark.parametrize(
2162+
"dtype", get_all_dtypes(no_bool=True, no_none=True)
2163+
)
2164+
def test_lu_solve(self, shape, rhs_cols, order, dtype):
2165+
a_np = self._make_nonsingular_np(shape, dtype, order)
2166+
a_dp = dpnp.array(a_np, order=order)
2167+
2168+
n = shape[0]
2169+
if rhs_cols is None:
2170+
b_np = generate_random_numpy_array((n,), dtype, order)
2171+
else:
2172+
b_np = generate_random_numpy_array((n, rhs_cols), dtype, order)
2173+
b_dp = dpnp.array(b_np, order=order)
2174+
2175+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2176+
x = dpnp.linalg.lu_solve(
2177+
(lu, piv), b_dp, trans=0, overwrite_b=False, check_finite=False
2178+
)
2179+
2180+
# check A @ x = b
2181+
Ax = a_dp @ x
2182+
assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6)
2183+
2184+
@pytest.mark.parametrize("trans", [0, 1, 2])
2185+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2186+
def test_trans(self, trans, dtype):
2187+
n = 4
2188+
a_np = self._make_nonsingular_np((n, n), dtype, order="F")
2189+
a_dp = dpnp.array(a_np, order="F")
2190+
b_dp = dpnp.array(generate_random_numpy_array((n, 2), dtype, "F"))
2191+
2192+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2193+
x = dpnp.linalg.lu_solve(
2194+
(lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False
2195+
)
2196+
2197+
if trans == 0:
2198+
lhs = a_dp @ x
2199+
elif trans == 1:
2200+
lhs = a_dp.T @ x
2201+
else: # trans == 2
2202+
lhs = a_dp.conj().T @ x
2203+
2204+
assert dpnp.allclose(lhs, b_dp, rtol=1e-6, atol=1e-6)
2205+
2206+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2207+
def test_overwrite_inplace(self, dtype):
2208+
a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
2209+
b_dp = dpnp.array([1, 0], dtype=dtype, order="F")
2210+
b_orig = b_dp.copy()
2211+
2212+
lu, piv = dpnp.linalg.lu_factor(
2213+
a_dp, overwrite_a=False, check_finite=False
2214+
)
2215+
x = dpnp.linalg.lu_solve(
2216+
(lu, piv), b_dp, trans=0, overwrite_b=True, check_finite=False
2217+
)
2218+
2219+
assert x is b_dp
2220+
assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6)
2221+
2222+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2223+
def test_overwrite_copy_special(self, dtype):
2224+
a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
2225+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2226+
2227+
# F-contig but dtype != res_type
2228+
b1 = dpnp.array([1, 0], dtype=dpnp.int32, order="F")
2229+
x1 = dpnp.linalg.lu_solve(
2230+
(lu, piv), b1, overwrite_b=True, check_finite=False
2231+
)
2232+
assert x1 is not b1
2233+
2234+
# F-contig, match dtype but read-only input
2235+
b2 = dpnp.array([1, 0], dtype=dtype, order="F")
2236+
b2.flags["WRITABLE"] = False
2237+
x2 = dpnp.linalg.lu_solve(
2238+
(lu, piv), b2, overwrite_b=True, check_finite=False
2239+
)
2240+
assert x2 is not b2
2241+
2242+
for x in (x1, x2):
2243+
assert dpnp.allclose(
2244+
a_dp @ x,
2245+
dpnp.array([1, 0], dtype=x.dtype),
2246+
rtol=1e-6,
2247+
atol=1e-6,
2248+
)
2249+
2250+
@pytest.mark.parametrize(
2251+
"dtype_a", get_all_dtypes(no_bool=True, no_none=True)
2252+
)
2253+
@pytest.mark.parametrize(
2254+
"dtype_b", get_all_dtypes(no_bool=True, no_none=True)
2255+
)
2256+
def test_diff_type(self, dtype_a, dtype_b):
2257+
a_np = self._make_nonsingular_np((3, 3), dtype_a, order="F")
2258+
a_dp = dpnp.array(a_np, order="F")
2259+
2260+
b_np = generate_random_numpy_array((3,), dtype_b, order="F")
2261+
b_dp = dpnp.array(b_np, order="F")
2262+
2263+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2264+
x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
2265+
assert dpnp.allclose(
2266+
a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-6, atol=1e-6
2267+
)
2268+
2269+
def test_strided_rhs(self):
2270+
n = 7
2271+
a_np = self._make_nonsingular_np(
2272+
(n, n), dpnp.default_float_type(), order="F"
2273+
)
2274+
a_dp = dpnp.array(a_np, order="F")
2275+
2276+
rhs_full = (
2277+
dpnp.arange(n * n, dtype=dpnp.default_float_type()).reshape(
2278+
n, n, order="F"
2279+
)
2280+
+ 1.0
2281+
)
2282+
b_dp = rhs_full[:, ::2][:, :3]
2283+
2284+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2285+
x = dpnp.linalg.lu_solve(
2286+
(lu, piv), b_dp, overwrite_b=False, check_finite=False
2287+
)
2288+
2289+
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6)
2290+
2291+
@pytest.mark.skip("Not implemented yet")
2292+
@pytest.mark.parametrize(
2293+
"b_shape",
2294+
[
2295+
(4,),
2296+
(4, 1),
2297+
(4, 3),
2298+
# (1, 4, 3),
2299+
# (2, 4, 3),
2300+
# (1, 1, 4, 3)
2301+
],
2302+
)
2303+
def test_broadcast_rhs(self, b_shape):
2304+
dtype = dpnp.default_float_type()
2305+
2306+
a_np = self._make_nonsingular_np((4, 4), dtype, order="F")
2307+
a_dp = dpnp.array(a_np, order="F")
2308+
2309+
b_np = generate_random_numpy_array(b_shape, dtype, order="F")
2310+
b_dp = dpnp.array(b_np, order="F")
2311+
2312+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2313+
x = dpnp.linalg.lu_solve(
2314+
(lu, piv), b_dp, overwrite_b=True, check_finite=False
2315+
)
2316+
2317+
assert x.shape == b_dp.shape
2318+
2319+
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6)
2320+
2321+
@pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)])
2322+
@pytest.mark.parametrize("rhs_cols", [None, 0, 3])
2323+
def test_empty_shapes(self, shape, rhs_cols):
2324+
a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
2325+
if min(shape) > 0:
2326+
for i in range(min(shape)):
2327+
a_dp[i, i] = a_dp.dtype.type(1.0)
2328+
2329+
n = shape[0]
2330+
if rhs_cols is None:
2331+
b_shape = (n,)
2332+
else:
2333+
b_shape = (n, rhs_cols)
2334+
b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
2335+
2336+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2337+
x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
2338+
2339+
assert x.shape == b_shape
2340+
2341+
@pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
2342+
def test_check_finite_raises(self, bad):
2343+
a_dp = dpnp.array([[1.0, 0.0], [0.0, 1.0]], order="F")
2344+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2345+
2346+
b_bad = dpnp.array([1.0, bad], order="F")
2347+
assert_raises(
2348+
ValueError,
2349+
dpnp.linalg.lu_solve,
2350+
(lu, piv),
2351+
b_bad,
2352+
check_finite=True,
2353+
)
2354+
2355+
21472356
class TestMatrixPower:
21482357
@pytest.mark.parametrize("dtype", get_all_dtypes())
21492358
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)