Skip to content

Commit 5e21a02

Browse files
Add TestLuSolveBatched
1 parent 14feed3 commit 5e21a02

File tree

1 file changed

+235
-3
lines changed

1 file changed

+235
-3
lines changed

dpnp/tests/test_linalg.py

Lines changed: 235 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,9 +2300,6 @@ def test_strided_rhs(self):
23002300
(4,),
23012301
(4, 1),
23022302
(4, 3),
2303-
# (1, 4, 3),
2304-
# (2, 4, 3),
2305-
# (1, 1, 4, 3)
23062303
],
23072304
)
23082305
def test_broadcast_rhs(self, b_shape):
@@ -2358,6 +2355,241 @@ def test_check_finite_raises(self, bad):
23582355
)
23592356

23602357

2358+
class TestLuSolveBatched:
2359+
@staticmethod
2360+
def _make_nonsingular_nd_np(shape, dtype, order):
2361+
A = generate_random_numpy_array(shape, dtype, order)
2362+
n = shape[-1]
2363+
A3 = A.reshape((-1, n, n))
2364+
for B in A3:
2365+
off = numpy.sum(numpy.abs(B), axis=1) - numpy.abs(numpy.diag(B))
2366+
B[numpy.arange(n), numpy.arange(n)] = A.dtype.type(off + 1.0)
2367+
A = A3.reshape(shape)
2368+
# Ensure reshapes did not break memory order
2369+
A = numpy.array(A, order=order)
2370+
return A
2371+
2372+
@staticmethod
2373+
def _expected_x_shape(a_shape, b_shape):
2374+
n = a_shape[-1]
2375+
assert a_shape[-2] == n
2376+
2377+
a_batch = a_shape[:-2]
2378+
if len(b_shape) >= 2 and b_shape[-2] == n:
2379+
# b : (..., n, nrhs)
2380+
k = b_shape[-1]
2381+
b_batch = b_shape[:-2]
2382+
exp_batch = numpy.broadcast_shapes(a_batch, b_batch)
2383+
return exp_batch + (n, k)
2384+
else:
2385+
# b : (..., n)
2386+
assert b_shape[-1] == n, "b's last dim must equal n"
2387+
b_batch = b_shape[:-1]
2388+
exp_batch = numpy.broadcast_shapes(a_batch, b_batch)
2389+
return exp_batch + (n,)
2390+
2391+
@pytest.mark.parametrize(
2392+
"a_shape, b_shape",
2393+
[
2394+
((1, 2, 2), (2,)),
2395+
((2, 4, 4), (4,)),
2396+
((2, 4, 4), (4, 3)),
2397+
((2, 4, 4), (2, 4, 4)),
2398+
((2, 4, 4), (1, 4, 3)),
2399+
((2, 4, 4), (2, 4, 2)),
2400+
((2, 3, 4, 4), (1, 3, 4, 2)),
2401+
((2, 3, 4, 4), (2, 1, 4, 2)),
2402+
((3, 4, 4), (1, 1, 4, 2)),
2403+
],
2404+
)
2405+
@pytest.mark.parametrize("order", ["C", "F"])
2406+
@pytest.mark.parametrize(
2407+
"dtype", get_all_dtypes(no_bool=True, no_none=True)
2408+
)
2409+
def test_lu_solve_batched(self, a_shape, b_shape, dtype, order):
2410+
a_np = self._make_nonsingular_nd_np(a_shape, dtype, order)
2411+
a_dp = dpnp.array(a_np, order=order)
2412+
2413+
b_np = generate_random_numpy_array(b_shape, dtype, order)
2414+
b_dp = dpnp.array(b_np, order=order)
2415+
2416+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2417+
x = dpnp.linalg.lu_solve(
2418+
(lu, piv), b_dp, overwrite_b=True, check_finite=False
2419+
)
2420+
2421+
exp_shape = self._expected_x_shape(a_shape, b_shape)
2422+
assert x.shape == exp_shape
2423+
2424+
if b_dp.ndim > 1:
2425+
Ax = a_dp @ x
2426+
else:
2427+
Ax = (a_dp @ x[..., None])[..., 0]
2428+
b_exp = dpnp.broadcast_to(b_dp, exp_shape)
2429+
assert dpnp.allclose(Ax, b_exp, rtol=1e-5, atol=1e-5)
2430+
2431+
@pytest.mark.parametrize("trans", [0, 1, 2])
2432+
@pytest.mark.parametrize("order", ["C", "F"])
2433+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2434+
def test_trans(self, trans, order, dtype):
2435+
a_shape = (3, 4, 4)
2436+
b_shape = (3, 4, 2)
2437+
2438+
a_np = self._make_nonsingular_nd_np(a_shape, dtype, order)
2439+
a_dp = dpnp.array(a_np, order=order)
2440+
b_dp = dpnp.array(
2441+
generate_random_numpy_array(b_shape, dtype, order), order=order
2442+
)
2443+
2444+
lu, piv = dpnp.linalg.lu_factor(
2445+
a_dp, overwrite_a=False, check_finite=False
2446+
)
2447+
x = dpnp.linalg.lu_solve(
2448+
(lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False
2449+
)
2450+
2451+
if trans == 0:
2452+
lhs = a_dp @ x
2453+
elif trans == 1:
2454+
lhs = dpnp.swapaxes(a_dp, -1, -2) @ x
2455+
else: # trans == 2
2456+
lhs = dpnp.conj(dpnp.swapaxes(a_dp, -1, -2)) @ x
2457+
2458+
assert dpnp.allclose(lhs, b_dp, rtol=1e-5, atol=1e-5)
2459+
2460+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2461+
@pytest.mark.parametrize("order", ["C", "F"])
2462+
def test_overwrite(self, dtype, order):
2463+
a_np = self._make_nonsingular_nd_np((2, 4, 4), dtype, order)
2464+
a_dp = dpnp.array(a_np, order=order)
2465+
2466+
lu, piv = dpnp.linalg.lu_factor(
2467+
a_dp, overwrite_a=False, check_finite=False
2468+
)
2469+
2470+
b_dp = dpnp.array(
2471+
generate_random_numpy_array((2, 4, 2), dtype, "F"), order="F"
2472+
)
2473+
b_dp_orig = b_dp.copy()
2474+
x = dpnp.linalg.lu_solve(
2475+
(lu, piv), b_dp, overwrite_b=True, check_finite=False
2476+
)
2477+
2478+
assert x is not b_dp
2479+
assert dpnp.allclose(b_dp, b_dp_orig)
2480+
2481+
assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5)
2482+
2483+
def test_strided(self):
2484+
n, B = 4, 6
2485+
a_np = self._make_nonsingular_nd_np(
2486+
(B, n, n), dpnp.default_float_type(), "F"
2487+
)
2488+
a_dp = dpnp.array(a_np, order="F")
2489+
2490+
a_stride = a_dp[::2]
2491+
rhs_full = (
2492+
dpnp.arange(B * n * 3, dtype=dpnp.default_float_type()).reshape(
2493+
B, n, 3, order="F"
2494+
)
2495+
+ 1.0
2496+
)
2497+
b_dp = rhs_full[::2, :, ::-1]
2498+
2499+
lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False)
2500+
x = dpnp.linalg.lu_solve(
2501+
(lu, piv), b_dp, overwrite_b=False, check_finite=False
2502+
)
2503+
2504+
assert dpnp.allclose(a_stride @ x, b_dp, rtol=1e-5, atol=1e-5)
2505+
2506+
@pytest.mark.parametrize(
2507+
"dtype_a", get_all_dtypes(no_bool=True, no_none=True)
2508+
)
2509+
@pytest.mark.parametrize(
2510+
"dtype_b", get_all_dtypes(no_bool=True, no_none=True)
2511+
)
2512+
@pytest.mark.parametrize("b_shape", [(4, 2), (1, 4, 2), (2, 4, 2)])
2513+
def test_diff_type(self, dtype_a, dtype_b, b_shape):
2514+
B, n, k = 2, 4, 2
2515+
a_np = self._make_nonsingular_nd_np((B, n, n), dtype_a, "F")
2516+
a_dp = dpnp.array(a_np, order="F")
2517+
2518+
b_np = generate_random_numpy_array(b_shape, dtype_b, "F")
2519+
b_dp = dpnp.array(b_np, order="F")
2520+
2521+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2522+
x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
2523+
2524+
exp_shape = (B, n, k)
2525+
assert x.shape == exp_shape
2526+
2527+
b_exp = dpnp.broadcast_to(b_dp, exp_shape)
2528+
assert dpnp.allclose(
2529+
a_dp @ x, b_exp.astype(x.dtype, copy=False), rtol=1e-5, atol=1e-5
2530+
)
2531+
2532+
@pytest.mark.parametrize(
2533+
"a_shape, b_shape",
2534+
[
2535+
((0, 3, 3), (0, 3)),
2536+
((2, 0, 0), (2, 0)),
2537+
((0, 0, 0), (0, 0)),
2538+
],
2539+
)
2540+
def test_empty_inputs(self, a_shape, b_shape):
2541+
a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F")
2542+
b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
2543+
2544+
lu, piv = dpnp.linalg.lu_factor(a, check_finite=False)
2545+
x = dpnp.linalg.lu_solve((lu, piv), b, check_finite=False)
2546+
2547+
assert x.shape == b_shape
2548+
2549+
def test_check_finite_raises(self):
2550+
B, n = 2, 3
2551+
a_np = self._make_nonsingular_nd_np(
2552+
(B, n, n), dpnp.default_float_type(), "F"
2553+
)
2554+
a_dp = dpnp.array(a_np, order="F")
2555+
lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
2556+
2557+
b_bad = dpnp.ones((B, n), dtype=dpnp.default_float_type(), order="F")
2558+
b_bad[1, 0] = dpnp.nan
2559+
assert_raises(
2560+
ValueError,
2561+
dpnp.linalg.lu_solve,
2562+
(lu, piv),
2563+
b_bad,
2564+
check_finite=True,
2565+
)
2566+
2567+
@pytest.mark.parametrize(
2568+
"a_shape, b_shape",
2569+
[
2570+
((2, 4, 4), (2,)),
2571+
((2, 4, 4), (2, 4)),
2572+
((2, 4, 4), (4, 4, 2)),
2573+
((2, 4, 4), (2, 3, 4, 2)),
2574+
((2, 3, 4, 4), (3, 4)),
2575+
((2, 3, 4, 4), (2, 4)),
2576+
((2, 3, 4, 4), (2, 3, 5, 2)),
2577+
],
2578+
)
2579+
def test_invalid_shapes(self, a_shape, b_shape):
2580+
dtype = dpnp.default_float_type()
2581+
a = dpnp.array(
2582+
self._make_nonsingular_nd_np(a_shape, dtype, "F"), order="F"
2583+
)
2584+
b = dpnp.array(
2585+
generate_random_numpy_array(b_shape, dtype, "F"), order="F"
2586+
)
2587+
2588+
lu, piv = dpnp.linalg.lu_factor(a, check_finite=False)
2589+
with pytest.raises(ValueError):
2590+
dpnp.linalg.lu_solve((lu, piv), b, check_finite=False)
2591+
2592+
23612593
class TestMatrixPower:
23622594
@pytest.mark.parametrize("dtype", get_all_dtypes())
23632595
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)