Skip to content

Commit 94197e0

Browse files
authored
ENH: linalg.solve: add assume_a='banded' (scipy#21726)
1 parent 35ad591 commit 94197e0

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

scipy/linalg/_basic.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,24 @@ def _solve_check(n, info, lamch=None, rcond=None):
5151

5252
def _find_matrix_structure(a):
5353
n = a.shape[0]
54-
below, above = bandwidth(a)
55-
56-
if below == above == 0:
57-
return 'diagonal'
58-
elif above == 0:
59-
return 'lower triangular'
60-
elif below == 0:
61-
return 'upper triangular'
62-
elif above <= 1 and below <= 1 and n > 3:
63-
return 'tridiagonal'
64-
65-
if np.issubdtype(a.dtype, np.complexfloating) and ishermitian(a):
66-
return 'hermitian'
54+
n_below, n_above = bandwidth(a)
55+
56+
if n_below == n_above == 0:
57+
kind = 'diagonal'
58+
elif n_above == 0:
59+
kind = 'lower triangular'
60+
elif n_below == 0:
61+
kind = 'upper triangular'
62+
elif n_above <= 1 and n_below <= 1 and n > 3:
63+
kind = 'tridiagonal'
64+
elif np.issubdtype(a.dtype, np.complexfloating) and ishermitian(a):
65+
kind = 'hermitian'
6766
elif issymmetric(a):
68-
return 'symmetric'
67+
kind = 'symmetric'
68+
else:
69+
kind = 'general'
6970

70-
return 'general'
71+
return kind, n_below, n_above
7172

7273

7374
def solve(a, b, lower=False, overwrite_a=False,
@@ -84,6 +85,7 @@ def solve(a, b, lower=False, overwrite_a=False,
8485
=================== ================================
8586
diagonal 'diagonal'
8687
tridiagonal 'tridiagonal'
88+
banded 'banded'
8789
upper triangular 'upper triangular'
8890
lower triangular 'lower triangular'
8991
symmetric 'symmetric' (or 'sym')
@@ -201,7 +203,7 @@ def solve(a, b, lower=False, overwrite_a=False,
201203
b1 = b1[:, None]
202204
b_is_1D = True
203205

204-
if assume_a not in {None, 'diagonal', 'tridiagonal', 'lower triangular',
206+
if assume_a not in {None, 'diagonal', 'tridiagonal', 'banded', 'lower triangular',
205207
'upper triangular', 'symmetric', 'hermitian',
206208
'positive definite', 'general', 'sym', 'her', 'pos', 'gen'}:
207209
raise ValueError(f'{assume_a} is not a recognized matrix structure')
@@ -211,8 +213,9 @@ def solve(a, b, lower=False, overwrite_a=False,
211213
if assume_a in {'hermitian', 'her'} and not np.iscomplexobj(a1):
212214
assume_a = 'symmetric'
213215

216+
n_below, n_above = None, None
214217
if assume_a is None:
215-
assume_a = _find_matrix_structure(a1)
218+
assume_a, n_below, n_above = _find_matrix_structure(a1)
216219

217220
# Get the correct lamch function.
218221
# The LAMCH functions only exists for S and D
@@ -301,6 +304,20 @@ def solve(a, b, lower=False, overwrite_a=False,
301304
x, info = _gttrs(dl, d, du, du2, ipiv, b1, overwrite_b=overwrite_b)
302305
_solve_check(n, info)
303306
rcond, info = _gtcon(dl, d, du, du2, ipiv, anorm)
307+
# Banded case
308+
elif assume_a == 'banded':
309+
a1, n_below, n_above = ((a1.T, n_above, n_below) if transposed
310+
else (a1, n_below, n_above))
311+
n_below, n_above = bandwidth(a1) if n_below is None else (n_below, n_above)
312+
ab = _to_banded(n_below, n_above, a1)
313+
gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
314+
# Next two lines copied from `solve_banded`
315+
a2 = np.zeros((2*n_below + n_above + 1, ab.shape[1]), dtype=gbsv.dtype)
316+
a2[n_below:, :] = ab
317+
_, _, x, info = gbsv(n_below, n_above, a2, b1,
318+
overwrite_ab=True, overwrite_b=overwrite_b)
319+
_solve_check(n, info)
320+
# TODO: wrap gbcon and use to get rcond
304321
# Triangular case
305322
elif assume_a in {'lower triangular', 'upper triangular'}:
306323
lower = assume_a == 'lower triangular'
@@ -363,6 +380,18 @@ def _matrix_norm_general(norm, a, check_finite):
363380
return lange(norm, a)
364381

365382

383+
def _to_banded(n_below, n_above, a):
384+
n = a.shape[0]
385+
rows = n_above + n_below + 1
386+
ab = np.zeros((rows, n), dtype=a.dtype)
387+
ab[n_above] = np.diag(a)
388+
for i in range(1, n_above + 1):
389+
ab[n_above - i, i:] = np.diag(a, i)
390+
for i in range(1, n_below + 1):
391+
ab[n_above + i, :-i] = np.diag(a, -i)
392+
return ab
393+
394+
366395
def _ensure_dtype_cdsz(*arrays):
367396
# Ensure that the dtype of arrays is one of the standard types
368397
# compatible with LAPACK functions (single or double precision

scipy/linalg/tests/test_basic.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -882,18 +882,19 @@ def test_empty_rhs(self):
882882
assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
883883

884884
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
885-
@pytest.mark.parametrize('assume_a', ['diagonal', 'tridiagonal', 'lower triangular',
886-
'upper triangular', 'symmetric', 'hermitian',
887-
'positive definite', 'general',
888-
'sym', 'her', 'pos', 'gen'])
885+
# "pos" and "positive definite" need to be added
886+
@pytest.mark.parametrize('assume_a', ['diagonal', 'tridiagonal', 'banded',
887+
'lower triangular', 'upper triangular',
888+
'symmetric', 'hermitian',
889+
'general', 'sym', 'her', 'gen'])
889890
@pytest.mark.parametrize('nrhs', [(), (5,)])
890891
@pytest.mark.parametrize('transposed', [True, False])
891892
@pytest.mark.parametrize('overwrite', [True, False])
892893
@pytest.mark.parametrize('fortran', [True, False])
893894
def test_structure_detection(self, dtype, assume_a, nrhs, transposed,
894895
overwrite, fortran):
895896
rng = np.random.default_rng(982345982439826)
896-
n = 5
897+
n = 5 if not assume_a == 'banded' else 20
897898
b = rng.random(size=(n,) + nrhs)
898899
A = rng.random(size=(n, n))
899900

@@ -911,6 +912,8 @@ def test_structure_detection(self, dtype, assume_a, nrhs, transposed,
911912
A = (np.diag(np.diag(A))
912913
+ np.diag(np.diag(A, -1), -1)
913914
+ np.diag(np.diag(A, 1), 1))
915+
elif assume_a == 'banded':
916+
A = np.triu(np.tril(A, 2), -1)
914917
elif assume_a in {'symmetric', 'sym'}:
915918
A = A + A.T
916919
elif assume_a in {'hermitian', 'her'}:
@@ -933,20 +936,22 @@ def test_structure_detection(self, dtype, assume_a, nrhs, transposed,
933936
return
934937

935938
res = solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
936-
transposed=transposed)
939+
transposed=transposed, assume_a=assume_a)
937940

941+
# Check that solution this solution is *correct*
942+
ref = np.linalg.solve(A_copy.T if transposed else A_copy, b_copy)
943+
assert_allclose(res, ref)
944+
945+
# Check that `solve` correctly identifies the structure and returns
946+
# *exactly* the same solution whether `assume_a` is specified or not
947+
if assume_a != 'banded': # structure detection removed for banded
948+
assert_equal(solve(A_copy, b_copy, transposed=transposed), res)
949+
950+
# Check that overwrite was respected
938951
if not overwrite:
939952
assert_equal(A, A_copy)
940953
assert_equal(b, b_copy)
941954

942-
assume_a = 'sym' if assume_a in {'positive definite', 'pos'} else assume_a
943-
944-
ref = solve(A_copy, b_copy, assume_a=assume_a, transposed=transposed)
945-
assert_equal(res, ref)
946-
947-
ref = np.linalg.solve(A_copy.T if transposed else A_copy, b_copy)
948-
assert_allclose(res, ref)
949-
950955

951956
class TestSolveTriangular:
952957

0 commit comments

Comments
 (0)