@@ -51,23 +51,24 @@ def _solve_check(n, info, lamch=None, rcond=None):
51
51
52
52
def _find_matrix_structure (a ):
53
53
n = a .shape [0 ]
54
- below , above = bandwidth (a )
55
-
56
- if below == above == 0 :
57
- return 'diagonal'
58
- elif above == 0 :
59
- return 'lower triangular'
60
- elif below == 0 :
61
- return 'upper triangular'
62
- elif above <= 1 and below <= 1 and n > 3 :
63
- return 'tridiagonal'
64
-
65
- if np .issubdtype (a .dtype , np .complexfloating ) and ishermitian (a ):
66
- return 'hermitian'
54
+ n_below , n_above = bandwidth (a )
55
+
56
+ if n_below == n_above == 0 :
57
+ kind = 'diagonal'
58
+ elif n_above == 0 :
59
+ kind = 'lower triangular'
60
+ elif n_below == 0 :
61
+ kind = 'upper triangular'
62
+ elif n_above <= 1 and n_below <= 1 and n > 3 :
63
+ kind = 'tridiagonal'
64
+ elif np .issubdtype (a .dtype , np .complexfloating ) and ishermitian (a ):
65
+ kind = 'hermitian'
67
66
elif issymmetric (a ):
68
- return 'symmetric'
67
+ kind = 'symmetric'
68
+ else :
69
+ kind = 'general'
69
70
70
- return 'general'
71
+ return kind , n_below , n_above
71
72
72
73
73
74
def solve (a , b , lower = False , overwrite_a = False ,
@@ -84,6 +85,7 @@ def solve(a, b, lower=False, overwrite_a=False,
84
85
=================== ================================
85
86
diagonal 'diagonal'
86
87
tridiagonal 'tridiagonal'
88
+ banded 'banded'
87
89
upper triangular 'upper triangular'
88
90
lower triangular 'lower triangular'
89
91
symmetric 'symmetric' (or 'sym')
@@ -201,7 +203,7 @@ def solve(a, b, lower=False, overwrite_a=False,
201
203
b1 = b1 [:, None ]
202
204
b_is_1D = True
203
205
204
- if assume_a not in {None , 'diagonal' , 'tridiagonal' , 'lower triangular' ,
206
+ if assume_a not in {None , 'diagonal' , 'tridiagonal' , 'banded' , ' lower triangular' ,
205
207
'upper triangular' , 'symmetric' , 'hermitian' ,
206
208
'positive definite' , 'general' , 'sym' , 'her' , 'pos' , 'gen' }:
207
209
raise ValueError (f'{ assume_a } is not a recognized matrix structure' )
@@ -211,8 +213,9 @@ def solve(a, b, lower=False, overwrite_a=False,
211
213
if assume_a in {'hermitian' , 'her' } and not np .iscomplexobj (a1 ):
212
214
assume_a = 'symmetric'
213
215
216
+ n_below , n_above = None , None
214
217
if assume_a is None :
215
- assume_a = _find_matrix_structure (a1 )
218
+ assume_a , n_below , n_above = _find_matrix_structure (a1 )
216
219
217
220
# Get the correct lamch function.
218
221
# The LAMCH functions only exists for S and D
@@ -301,6 +304,20 @@ def solve(a, b, lower=False, overwrite_a=False,
301
304
x , info = _gttrs (dl , d , du , du2 , ipiv , b1 , overwrite_b = overwrite_b )
302
305
_solve_check (n , info )
303
306
rcond , info = _gtcon (dl , d , du , du2 , ipiv , anorm )
307
+ # Banded case
308
+ elif assume_a == 'banded' :
309
+ a1 , n_below , n_above = ((a1 .T , n_above , n_below ) if transposed
310
+ else (a1 , n_below , n_above ))
311
+ n_below , n_above = bandwidth (a1 ) if n_below is None else (n_below , n_above )
312
+ ab = _to_banded (n_below , n_above , a1 )
313
+ gbsv , = get_lapack_funcs (('gbsv' ,), (a1 , b1 ))
314
+ # Next two lines copied from `solve_banded`
315
+ a2 = np .zeros ((2 * n_below + n_above + 1 , ab .shape [1 ]), dtype = gbsv .dtype )
316
+ a2 [n_below :, :] = ab
317
+ _ , _ , x , info = gbsv (n_below , n_above , a2 , b1 ,
318
+ overwrite_ab = True , overwrite_b = overwrite_b )
319
+ _solve_check (n , info )
320
+ # TODO: wrap gbcon and use to get rcond
304
321
# Triangular case
305
322
elif assume_a in {'lower triangular' , 'upper triangular' }:
306
323
lower = assume_a == 'lower triangular'
@@ -363,6 +380,18 @@ def _matrix_norm_general(norm, a, check_finite):
363
380
return lange (norm , a )
364
381
365
382
383
+ def _to_banded (n_below , n_above , a ):
384
+ n = a .shape [0 ]
385
+ rows = n_above + n_below + 1
386
+ ab = np .zeros ((rows , n ), dtype = a .dtype )
387
+ ab [n_above ] = np .diag (a )
388
+ for i in range (1 , n_above + 1 ):
389
+ ab [n_above - i , i :] = np .diag (a , i )
390
+ for i in range (1 , n_below + 1 ):
391
+ ab [n_above + i , :- i ] = np .diag (a , - i )
392
+ return ab
393
+
394
+
366
395
def _ensure_dtype_cdsz (* arrays ):
367
396
# Ensure that the dtype of arrays is one of the standard types
368
397
# compatible with LAPACK functions (single or double precision
0 commit comments