13
13
from ._decomp import _asarray_validated
14
14
from . import _decomp , _decomp_svd
15
15
from ._solve_toeplitz import levinson
16
- from ._cythonized_array_utils import find_det_from_lu
16
+ from ._cythonized_array_utils import (find_det_from_lu , bandwidth , issymmetric ,
17
+ ishermitian )
17
18
18
19
__all__ = ['solve' , 'solve_triangular' , 'solveh_banded' , 'solve_banded' ,
19
20
'solve_toeplitz' , 'solve_circulant' , 'inv' , 'det' , 'lstsq' ,
@@ -48,8 +49,29 @@ def _solve_check(n, info, lamch=None, rcond=None):
48
49
LinAlgWarning , stacklevel = 3 )
49
50
50
51
52
+ def _find_matrix_structure (a ):
53
+ n = a .shape [0 ]
54
+ below , above = bandwidth (a )
55
+
56
+ if below == above == 0 :
57
+ return 'diagonal'
58
+ elif above == 0 :
59
+ return 'lower triangular'
60
+ elif below == 0 :
61
+ return 'upper triangular'
62
+ elif above <= 1 and below <= 1 and n > 3 :
63
+ return 'tridiagonal'
64
+
65
+ if np .issubdtype (a .dtype , np .complexfloating ) and ishermitian (a ):
66
+ return 'hermitian'
67
+ elif issymmetric (a ):
68
+ return 'symmetric'
69
+
70
+ return 'general'
71
+
72
+
51
73
def solve (a , b , lower = False , overwrite_a = False ,
52
- overwrite_b = False , check_finite = True , assume_a = 'gen' ,
74
+ overwrite_b = False , check_finite = True , assume_a = None ,
53
75
transposed = False ):
54
76
"""
55
77
Solves the linear equation set ``a @ x == b`` for the unknown ``x``
@@ -59,19 +81,16 @@ def solve(a, b, lower=False, overwrite_a=False,
59
81
corresponding string to ``assume_a`` key chooses the dedicated solver.
60
82
The available options are
61
83
62
- =================== ========
63
- generic matrix 'gen'
64
- symmetric 'sym'
65
- hermitian 'her'
66
- positive definite 'pos'
67
- =================== ========
68
-
69
- If omitted, ``'gen'`` is the default structure.
70
-
71
- The datatype of the arrays define which solver is called regardless
72
- of the values. In other words, even when the complex array entries have
73
- precisely zero imaginary parts, the complex solver will be called based
74
- on the data type of the array.
84
+ =================== ================================
85
+ diagonal 'diagonal'
86
+ tridiagonal 'tridiagonal'
87
+ upper triangular 'upper triangular'
88
+ lower triangular 'lower triangular'
89
+ symmetric 'symmetric' (or 'sym')
90
+ hermitian 'hermitian' (or 'her')
91
+ positive definite 'positive definite' (or 'pos')
92
+ general 'general' (or 'gen')
93
+ =================== ================================
75
94
76
95
Parameters
77
96
----------
@@ -80,8 +99,8 @@ def solve(a, b, lower=False, overwrite_a=False,
80
99
b : (N, NRHS) array_like
81
100
Input data for the right hand side.
82
101
lower : bool, default: False
83
- Ignored if ``assume_a == 'gen '`` (the default). If True, the
84
- calculation uses only the data in the lower triangle of `a`;
102
+ Ignored unless ``assume_a`` is one of ``'sym '``, ``'her'``, or ``'pos'``.
103
+ If True, the calculation uses only the data in the lower triangle of `a`;
85
104
entries above the diagonal are ignored. If False (default), the
86
105
calculation uses only the data in the upper triangle of `a`; entries
87
106
below the diagonal are ignored.
@@ -93,8 +112,10 @@ def solve(a, b, lower=False, overwrite_a=False,
93
112
Whether to check that the input matrices contain only finite numbers.
94
113
Disabling may give a performance gain, but may result in problems
95
114
(crashes, non-termination) if the inputs do contain infinities or NaNs.
96
- assume_a : str, {'gen', 'sym', 'her', 'pos'}
97
- Valid entries are explained above.
115
+ assume_a : str, optional
116
+ Valid entries are described above.
117
+ If omitted or ``None``, checks are performed to identify structure so the
118
+ appropriate solver can be called.
98
119
transposed : bool, default: False
99
120
If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
100
121
for complex `a`.
@@ -122,10 +143,15 @@ def solve(a, b, lower=False, overwrite_a=False,
122
143
despite the apparent size mismatch. This is compatible with the
123
144
numpy.dot() behavior and the returned result is still 1-D array.
124
145
125
- The generic , symmetric, Hermitian and positive definite solutions are
146
+ The general , symmetric, Hermitian and positive definite solutions are
126
147
obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
127
148
LAPACK respectively.
128
149
150
+ The datatype of the arrays defines which solver is called regardless
151
+ of the values. In other words, even when the complex array entries have
152
+ precisely zero imaginary parts, the complex solver will be called based
153
+ on the data type of the array.
154
+
129
155
Examples
130
156
--------
131
157
Given `a` and `b`, solve for `x`:
@@ -146,6 +172,7 @@ def solve(a, b, lower=False, overwrite_a=False,
146
172
147
173
a1 = atleast_2d (_asarray_validated (a , check_finite = check_finite ))
148
174
b1 = atleast_1d (_asarray_validated (b , check_finite = check_finite ))
175
+ a1 , b1 = _ensure_dtype_cdsz (a1 , b1 )
149
176
n = a1 .shape [0 ]
150
177
151
178
overwrite_a = overwrite_a or _datacopied (a1 , a )
@@ -173,13 +200,18 @@ def solve(a, b, lower=False, overwrite_a=False,
173
200
b1 = b1 [:, None ]
174
201
b_is_1D = True
175
202
176
- if assume_a not in ('gen' , 'sym' , 'her' , 'pos' ):
203
+ if assume_a not in {None , 'diagonal' , 'tridiagonal' , 'lower triangular' ,
204
+ 'upper triangular' , 'symmetric' , 'hermitian' ,
205
+ 'positive definite' , 'general' , 'sym' , 'her' , 'pos' , 'gen' }:
177
206
raise ValueError (f'{ assume_a } is not a recognized matrix structure' )
178
207
179
208
# for a real matrix, describe it as "symmetric", not "hermitian"
180
209
# (lapack doesn't know what to do with real hermitian matrices)
181
- if assume_a == 'her' and not np .iscomplexobj (a1 ):
182
- assume_a = 'sym'
210
+ if assume_a in {'hermitian' , 'her' } and not np .iscomplexobj (a1 ):
211
+ assume_a = 'symmetric'
212
+
213
+ if assume_a is None :
214
+ assume_a = _find_matrix_structure (a1 )
183
215
184
216
# Get the correct lamch function.
185
217
# The LAMCH functions only exists for S and D
@@ -192,7 +224,12 @@ def solve(a, b, lower=False, overwrite_a=False,
192
224
# Currently we do not have the other forms of the norm calculators
193
225
# lansy, lanpo, lanhe.
194
226
# However, in any case they only reduce computations slightly...
195
- lange = get_lapack_funcs ('lange' , (a1 ,))
227
+ if assume_a == 'diagonal' :
228
+ lange = _lange_diagonal
229
+ elif assume_a == 'tridiagonal' :
230
+ lange = _lange_tridiagonal
231
+ else :
232
+ lange = get_lapack_funcs ('lange' , (a1 ,))
196
233
197
234
# Since the I-norm and 1-norm are the same for symmetric matrices
198
235
# we can collect them all in this one call
@@ -211,8 +248,10 @@ def solve(a, b, lower=False, overwrite_a=False,
211
248
212
249
anorm = lange (norm , a1 )
213
250
251
+ info , rcond = 0 , np .inf
252
+
214
253
# Generalized case 'gesv'
215
- if assume_a == ' gen' :
254
+ if assume_a in { 'general' , ' gen'} :
216
255
gecon , getrf , getrs = get_lapack_funcs (('gecon' , 'getrf' , 'getrs' ),
217
256
(a1 , b1 ))
218
257
lu , ipvt , info = getrf (a1 , overwrite_a = overwrite_a )
@@ -222,7 +261,7 @@ def solve(a, b, lower=False, overwrite_a=False,
222
261
_solve_check (n , info )
223
262
rcond , info = gecon (lu , anorm , norm = norm )
224
263
# Hermitian case 'hesv'
225
- elif assume_a == ' her' :
264
+ elif assume_a in { 'hermitian' , ' her'} :
226
265
hecon , hesv , hesv_lw = get_lapack_funcs (('hecon' , 'hesv' ,
227
266
'hesv_lwork' ), (a1 , b1 ))
228
267
lwork = _compute_lwork (hesv_lw , n , lower )
@@ -233,7 +272,7 @@ def solve(a, b, lower=False, overwrite_a=False,
233
272
_solve_check (n , info )
234
273
rcond , info = hecon (lu , ipvt , anorm )
235
274
# Symmetric case 'sysv'
236
- elif assume_a == ' sym' :
275
+ elif assume_a in { 'symmetric' , ' sym'} :
237
276
sycon , sysv , sysv_lw = get_lapack_funcs (('sycon' , 'sysv' ,
238
277
'sysv_lwork' ), (a1 , b1 ))
239
278
lwork = _compute_lwork (sysv_lw , n , lower )
@@ -243,6 +282,23 @@ def solve(a, b, lower=False, overwrite_a=False,
243
282
overwrite_b = overwrite_b )
244
283
_solve_check (n , info )
245
284
rcond , info = sycon (lu , ipvt , anorm )
285
+ # Diagonal case
286
+ elif assume_a == 'diagonal' :
287
+ diag_a = np .diag (a1 )
288
+ x = (b1 .T / diag_a ).T
289
+ abs_diag_a = np .abs (diag_a )
290
+ rcond = abs_diag_a .min () / abs_diag_a .max ()
291
+ # Tri-diagonal case
292
+ elif assume_a == 'tridiagonal' :
293
+ a1 = a1 .T if transposed else a1
294
+ dl , d , du = np .diag (a1 , - 1 ), np .diag (a1 , 0 ), np .diag (a1 , 1 )
295
+ _gtsv = get_lapack_funcs ('gtsv' , (a1 , b1 ))
296
+ x , info = _gtsv (dl , d , du , b1 , False , False , False , overwrite_b )[3 :]
297
+ # Triangular case
298
+ elif assume_a in {'lower triangular' , 'upper triangular' }:
299
+ lower = assume_a == 'lower triangular'
300
+ x = _solve_triangular (a1 , b1 , lower = lower , overwrite_b = overwrite_b ,
301
+ trans = transposed )
246
302
# Positive definite case 'posv'
247
303
else :
248
304
pocon , posv = get_lapack_funcs (('pocon' , 'posv' ),
@@ -261,6 +317,38 @@ def solve(a, b, lower=False, overwrite_a=False,
261
317
return x
262
318
263
319
320
+ def _lange_diagonal (_ , a ):
321
+ # Equivalent of dlange for diagonal matrix, assuming
322
+ # norm is either 'I' or '1' (really just not the Frobenius norm)
323
+ return np .abs (np .diag (a )).max ()
324
+
325
+
326
+ def _lange_tridiagonal (norm , a ):
327
+ # Equivalent of dlange for tridiagonal matrix, assuming
328
+ # norm is either 'I' or '1'
329
+ if norm == 'I' :
330
+ a = a .T
331
+ d = np .abs (np .diag (a ))
332
+ d [1 :] += np .abs (np .diag (a , 1 ))
333
+ d [:- 1 ] += np .abs (np .diag (a , - 1 ))
334
+ return d .max ()
335
+
336
+
337
+ def _ensure_dtype_cdsz (* arrays ):
338
+ # Ensure that the dtype of arrays is one of the standard types
339
+ # compatible with LAPACK functions (single or double precision
340
+ # real or complex).
341
+ dtype = np .result_type (* arrays )
342
+ if not np .issubdtype (dtype , np .inexact ):
343
+ return (array .astype (np .float64 ) for array in arrays )
344
+ complex = np .issubdtype (dtype , np .complexfloating )
345
+ if np .finfo (dtype ).bits <= 32 :
346
+ dtype = np .complex64 if complex else np .float32
347
+ elif np .finfo (dtype ).bits >= 64 :
348
+ dtype = np .complex128 if complex else np .float64
349
+ return (array .astype (dtype , copy = False ) for array in arrays )
350
+
351
+
264
352
def solve_triangular (a , b , trans = 0 , lower = False , unit_diagonal = False ,
265
353
overwrite_b = False , check_finite = True ):
266
354
"""
@@ -348,6 +436,13 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
348
436
349
437
overwrite_b = overwrite_b or _datacopied (b1 , b )
350
438
439
+ return _solve_triangular (a1 , b1 , trans , lower , unit_diagonal , overwrite_b )
440
+
441
+
442
+ # solve_triangular without the input validation
443
+ def _solve_triangular (a1 , b1 , trans = 0 , lower = False , unit_diagonal = False ,
444
+ overwrite_b = False ):
445
+
351
446
trans = {'N' : 0 , 'T' : 1 , 'C' : 2 }.get (trans , trans )
352
447
trtrs , = get_lapack_funcs (('trtrs' ,), (a1 , b1 ))
353
448
if a1 .flags .f_contiguous or trans == 2 :
0 commit comments