Skip to content

Commit 74b21b6

Browse files
authored
Merge branch 'main' into gh20048
2 parents 3939dba + 33c5bfa commit 74b21b6

File tree

4 files changed

+208
-30
lines changed

4 files changed

+208
-30
lines changed

scipy/linalg/_basic.py

Lines changed: 122 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from ._decomp import _asarray_validated
1414
from . import _decomp, _decomp_svd
1515
from ._solve_toeplitz import levinson
16-
from ._cythonized_array_utils import find_det_from_lu
16+
from ._cythonized_array_utils import (find_det_from_lu, bandwidth, issymmetric,
17+
ishermitian)
1718

1819
__all__ = ['solve', 'solve_triangular', 'solveh_banded', 'solve_banded',
1920
'solve_toeplitz', 'solve_circulant', 'inv', 'det', 'lstsq',
@@ -48,8 +49,29 @@ def _solve_check(n, info, lamch=None, rcond=None):
4849
LinAlgWarning, stacklevel=3)
4950

5051

52+
def _find_matrix_structure(a):
53+
n = a.shape[0]
54+
below, above = bandwidth(a)
55+
56+
if below == above == 0:
57+
return 'diagonal'
58+
elif above == 0:
59+
return 'lower triangular'
60+
elif below == 0:
61+
return 'upper triangular'
62+
elif above <= 1 and below <= 1 and n > 3:
63+
return 'tridiagonal'
64+
65+
if np.issubdtype(a.dtype, np.complexfloating) and ishermitian(a):
66+
return 'hermitian'
67+
elif issymmetric(a):
68+
return 'symmetric'
69+
70+
return 'general'
71+
72+
5173
def solve(a, b, lower=False, overwrite_a=False,
52-
overwrite_b=False, check_finite=True, assume_a='gen',
74+
overwrite_b=False, check_finite=True, assume_a=None,
5375
transposed=False):
5476
"""
5577
Solves the linear equation set ``a @ x == b`` for the unknown ``x``
@@ -59,19 +81,16 @@ def solve(a, b, lower=False, overwrite_a=False,
5981
corresponding string to ``assume_a`` key chooses the dedicated solver.
6082
The available options are
6183
62-
=================== ========
63-
generic matrix 'gen'
64-
symmetric 'sym'
65-
hermitian 'her'
66-
positive definite 'pos'
67-
=================== ========
68-
69-
If omitted, ``'gen'`` is the default structure.
70-
71-
The datatype of the arrays define which solver is called regardless
72-
of the values. In other words, even when the complex array entries have
73-
precisely zero imaginary parts, the complex solver will be called based
74-
on the data type of the array.
84+
=================== ================================
85+
diagonal 'diagonal'
86+
tridiagonal 'tridiagonal'
87+
upper triangular 'upper triangular'
88+
lower triangular 'lower triangular'
89+
symmetric 'symmetric' (or 'sym')
90+
hermitian 'hermitian' (or 'her')
91+
positive definite 'positive definite' (or 'pos')
92+
general 'general' (or 'gen')
93+
=================== ================================
7594
7695
Parameters
7796
----------
@@ -80,8 +99,8 @@ def solve(a, b, lower=False, overwrite_a=False,
8099
b : (N, NRHS) array_like
81100
Input data for the right hand side.
82101
lower : bool, default: False
83-
Ignored if ``assume_a == 'gen'`` (the default). If True, the
84-
calculation uses only the data in the lower triangle of `a`;
102+
Ignored unless ``assume_a`` is one of ``'sym'``, ``'her'``, or ``'pos'``.
103+
If True, the calculation uses only the data in the lower triangle of `a`;
85104
entries above the diagonal are ignored. If False (default), the
86105
calculation uses only the data in the upper triangle of `a`; entries
87106
below the diagonal are ignored.
@@ -93,8 +112,10 @@ def solve(a, b, lower=False, overwrite_a=False,
93112
Whether to check that the input matrices contain only finite numbers.
94113
Disabling may give a performance gain, but may result in problems
95114
(crashes, non-termination) if the inputs do contain infinities or NaNs.
96-
assume_a : str, {'gen', 'sym', 'her', 'pos'}
97-
Valid entries are explained above.
115+
assume_a : str, optional
116+
Valid entries are described above.
117+
If omitted or ``None``, checks are performed to identify structure so the
118+
appropriate solver can be called.
98119
transposed : bool, default: False
99120
If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
100121
for complex `a`.
@@ -122,10 +143,15 @@ def solve(a, b, lower=False, overwrite_a=False,
122143
despite the apparent size mismatch. This is compatible with the
123144
numpy.dot() behavior and the returned result is still 1-D array.
124145
125-
The generic, symmetric, Hermitian and positive definite solutions are
146+
The general, symmetric, Hermitian and positive definite solutions are
126147
obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
127148
LAPACK respectively.
128149
150+
The datatype of the arrays define which solver is called regardless
151+
of the values. In other words, even when the complex array entries have
152+
precisely zero imaginary parts, the complex solver will be called based
153+
on the data type of the array.
154+
129155
Examples
130156
--------
131157
Given `a` and `b`, solve for `x`:
@@ -146,6 +172,7 @@ def solve(a, b, lower=False, overwrite_a=False,
146172

147173
a1 = atleast_2d(_asarray_validated(a, check_finite=check_finite))
148174
b1 = atleast_1d(_asarray_validated(b, check_finite=check_finite))
175+
a1, b1 = _ensure_dtype_cdsz(a1, b1)
149176
n = a1.shape[0]
150177

151178
overwrite_a = overwrite_a or _datacopied(a1, a)
@@ -173,13 +200,18 @@ def solve(a, b, lower=False, overwrite_a=False,
173200
b1 = b1[:, None]
174201
b_is_1D = True
175202

176-
if assume_a not in ('gen', 'sym', 'her', 'pos'):
203+
if assume_a not in {None, 'diagonal', 'tridiagonal', 'lower triangular',
204+
'upper triangular', 'symmetric', 'hermitian',
205+
'positive definite', 'general', 'sym', 'her', 'pos', 'gen'}:
177206
raise ValueError(f'{assume_a} is not a recognized matrix structure')
178207

179208
# for a real matrix, describe it as "symmetric", not "hermitian"
180209
# (lapack doesn't know what to do with real hermitian matrices)
181-
if assume_a == 'her' and not np.iscomplexobj(a1):
182-
assume_a = 'sym'
210+
if assume_a in {'hermitian', 'her'} and not np.iscomplexobj(a1):
211+
assume_a = 'symmetric'
212+
213+
if assume_a is None:
214+
assume_a = _find_matrix_structure(a1)
183215

184216
# Get the correct lamch function.
185217
# The LAMCH functions only exists for S and D
@@ -192,7 +224,12 @@ def solve(a, b, lower=False, overwrite_a=False,
192224
# Currently we do not have the other forms of the norm calculators
193225
# lansy, lanpo, lanhe.
194226
# However, in any case they only reduce computations slightly...
195-
lange = get_lapack_funcs('lange', (a1,))
227+
if assume_a == 'diagonal':
228+
lange = _lange_diagonal
229+
elif assume_a == 'tridiagonal':
230+
lange = _lange_tridiagonal
231+
else:
232+
lange = get_lapack_funcs('lange', (a1,))
196233

197234
# Since the I-norm and 1-norm are the same for symmetric matrices
198235
# we can collect them all in this one call
@@ -211,8 +248,10 @@ def solve(a, b, lower=False, overwrite_a=False,
211248

212249
anorm = lange(norm, a1)
213250

251+
info, rcond = 0, np.inf
252+
214253
# Generalized case 'gesv'
215-
if assume_a == 'gen':
254+
if assume_a in {'general', 'gen'}:
216255
gecon, getrf, getrs = get_lapack_funcs(('gecon', 'getrf', 'getrs'),
217256
(a1, b1))
218257
lu, ipvt, info = getrf(a1, overwrite_a=overwrite_a)
@@ -222,7 +261,7 @@ def solve(a, b, lower=False, overwrite_a=False,
222261
_solve_check(n, info)
223262
rcond, info = gecon(lu, anorm, norm=norm)
224263
# Hermitian case 'hesv'
225-
elif assume_a == 'her':
264+
elif assume_a in {'hermitian', 'her'}:
226265
hecon, hesv, hesv_lw = get_lapack_funcs(('hecon', 'hesv',
227266
'hesv_lwork'), (a1, b1))
228267
lwork = _compute_lwork(hesv_lw, n, lower)
@@ -233,7 +272,7 @@ def solve(a, b, lower=False, overwrite_a=False,
233272
_solve_check(n, info)
234273
rcond, info = hecon(lu, ipvt, anorm)
235274
# Symmetric case 'sysv'
236-
elif assume_a == 'sym':
275+
elif assume_a in {'symmetric', 'sym'}:
237276
sycon, sysv, sysv_lw = get_lapack_funcs(('sycon', 'sysv',
238277
'sysv_lwork'), (a1, b1))
239278
lwork = _compute_lwork(sysv_lw, n, lower)
@@ -243,6 +282,23 @@ def solve(a, b, lower=False, overwrite_a=False,
243282
overwrite_b=overwrite_b)
244283
_solve_check(n, info)
245284
rcond, info = sycon(lu, ipvt, anorm)
285+
# Diagonal case
286+
elif assume_a == 'diagonal':
287+
diag_a = np.diag(a1)
288+
x = (b1.T / diag_a).T
289+
abs_diag_a = np.abs(diag_a)
290+
rcond = abs_diag_a.min() / abs_diag_a.max()
291+
# Tri-diagonal case
292+
elif assume_a == 'tridiagonal':
293+
a1 = a1.T if transposed else a1
294+
dl, d, du = np.diag(a1, -1), np.diag(a1, 0), np.diag(a1, 1)
295+
_gtsv = get_lapack_funcs('gtsv', (a1, b1))
296+
x, info = _gtsv(dl, d, du, b1, False, False, False, overwrite_b)[3:]
297+
# Triangular case
298+
elif assume_a in {'lower triangular', 'upper triangular'}:
299+
lower = assume_a == 'lower triangular'
300+
x = _solve_triangular(a1, b1, lower=lower, overwrite_b=overwrite_b,
301+
trans=transposed)
246302
# Positive definite case 'posv'
247303
else:
248304
pocon, posv = get_lapack_funcs(('pocon', 'posv'),
@@ -261,6 +317,38 @@ def solve(a, b, lower=False, overwrite_a=False,
261317
return x
262318

263319

320+
def _lange_diagonal(_, a):
321+
# Equivalent of dlange for diagonal matrix, assuming
322+
# norm is either 'I' or '1' (really just not the Frobenius norm)
323+
return np.abs(np.diag(a)).max()
324+
325+
326+
def _lange_tridiagonal(norm, a):
327+
# Equivalent of dlange for tridiagonal matrix, assuming
328+
# norm is either 'I' or '1'
329+
if norm == 'I':
330+
a = a.T
331+
d = np.abs(np.diag(a))
332+
d[1:] += np.abs(np.diag(a, 1))
333+
d[:-1] += np.abs(np.diag(a, -1))
334+
return d.max()
335+
336+
337+
def _ensure_dtype_cdsz(*arrays):
338+
# Ensure that the dtype of arrays is one of the standard types
339+
# compatible with LAPACK functions (single or double precision
340+
# real or complex).
341+
dtype = np.result_type(*arrays)
342+
if not np.issubdtype(dtype, np.inexact):
343+
return (array.astype(np.float64) for array in arrays)
344+
complex = np.issubdtype(dtype, np.complexfloating)
345+
if np.finfo(dtype).bits <= 32:
346+
dtype = np.complex64 if complex else np.float32
347+
elif np.finfo(dtype).bits >= 64:
348+
dtype = np.complex128 if complex else np.float64
349+
return (array.astype(dtype, copy=False) for array in arrays)
350+
351+
264352
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
265353
overwrite_b=False, check_finite=True):
266354
"""
@@ -348,6 +436,13 @@ def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
348436

349437
overwrite_b = overwrite_b or _datacopied(b1, b)
350438

439+
return _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b)
440+
441+
442+
# solve_triangular without the input validation
443+
def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False,
444+
overwrite_b=False):
445+
351446
trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
352447
trtrs, = get_lapack_funcs(('trtrs',), (a1, b1))
353448
if a1.flags.f_contiguous or trans == 2:

scipy/linalg/tests/test_basic.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,11 @@ def test_singularity(self):
763763
assert_raises(LinAlgError, solve, a, b)
764764

765765
def test_ill_condition_warning(self):
766-
a = np.array([[1, 1], [1+1e-16, 1-1e-16]])
767-
b = np.ones(2)
766+
a = np.array([[1, 1, 1],
767+
[1+1e-16, 1-1e-16, 1],
768+
[1-1e-16, 1+1e-16, 1],
769+
])
770+
b = np.ones(3)
768771
with warnings.catch_warnings():
769772
warnings.simplefilter('error')
770773
assert_raises(LinAlgWarning, solve, a, b)
@@ -864,6 +867,72 @@ def test_empty_rhs(self):
864867
assert_(x.size == 0, 'Returned array is not empty')
865868
assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
866869

870+
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
871+
@pytest.mark.parametrize('assume_a', ['diagonal', 'tridiagonal', 'lower triangular',
872+
'upper triangular', 'symmetric', 'hermitian',
873+
'positive definite', 'general',
874+
'sym', 'her', 'pos', 'gen'])
875+
@pytest.mark.parametrize('nrhs', [(), (5,)])
876+
@pytest.mark.parametrize('transposed', [True, False])
877+
@pytest.mark.parametrize('overwrite', [True, False])
878+
@pytest.mark.parametrize('fortran', [True, False])
879+
def test_structure_detection(self, dtype, assume_a, nrhs, transposed,
880+
overwrite, fortran):
881+
rng = np.random.default_rng(982345982439826)
882+
n = 5
883+
b = rng.random(size=(n,) + nrhs)
884+
A = rng.random(size=(n, n))
885+
886+
if np.issubdtype(dtype, np.complexfloating):
887+
b = b + rng.random(size=(n,) + nrhs) * 1j
888+
A = A + rng.random(size=(n, n)) * 1j
889+
890+
if assume_a == 'diagonal':
891+
A = np.diag(np.diag(A))
892+
elif assume_a == 'lower triangular':
893+
A = np.tril(A)
894+
elif assume_a == 'upper triangular':
895+
A = np.triu(A)
896+
elif assume_a == 'tridiagonal':
897+
A = (np.diag(np.diag(A))
898+
+ np.diag(np.diag(A, -1), -1)
899+
+ np.diag(np.diag(A, 1), 1))
900+
elif assume_a in {'symmetric', 'sym'}:
901+
A = A + A.T
902+
elif assume_a in {'hermitian', 'her'}:
903+
A = A + A.conj().T
904+
elif assume_a in {'positive definite', 'pos'}:
905+
A = A + A.T
906+
A += np.diag(A.sum(axis=1))
907+
908+
if fortran:
909+
A = np.asfortranarray(A)
910+
911+
A_copy = A.copy(order='A')
912+
b_copy = b.copy()
913+
914+
if np.issubdtype(dtype, np.complexfloating) and transposed:
915+
message = "scipy.linalg.solve can currently..."
916+
with pytest.raises(NotImplementedError, match=message):
917+
solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
918+
transposed=transposed)
919+
return
920+
921+
res = solve(A, b, overwrite_a=overwrite, overwrite_b=overwrite,
922+
transposed=transposed)
923+
924+
if not overwrite:
925+
assert_equal(A, A_copy)
926+
assert_equal(b, b_copy)
927+
928+
assume_a = 'sym' if assume_a in {'positive definite', 'pos'} else assume_a
929+
930+
ref = solve(A_copy, b_copy, assume_a=assume_a, transposed=transposed)
931+
assert_equal(res, ref)
932+
933+
ref = np.linalg.solve(A_copy.T if transposed else A_copy, b_copy)
934+
assert_allclose(res, ref)
935+
867936

868937
class TestSolveTriangular:
869938

scipy/stats/_discrete_distns.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,7 @@ def _get_support(self, low, high):
12091209

12101210
def _pmf(self, k, low, high):
12111211
# randint.pmf(k) = 1./(high - low)
1212-
p = np.ones_like(k) / (high - low)
1212+
p = np.ones_like(k) / (np.asarray(high, dtype=np.int64) - low)
12131213
return np.where((k >= low) & (k < high), p, 0.)
12141214

12151215
def _cdf(self, x, low, high):

scipy/stats/tests/test_discrete_distns.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,8 @@ def test_gh20692(self):
650650

651651

652652
def test_gh20048():
653+
# gh-20048 reported an infinite loop in _drv2_ppfsingle
654+
# check that the one identified is resolved
653655
class test_dist_gen(stats.rv_discrete):
654656
def _cdf(self, k):
655657
return min(k / 100, 0.99)
@@ -659,3 +661,15 @@ def _cdf(self, k):
659661
message = "Arguments that bracket..."
660662
with pytest.raises(RuntimeError, match=message):
661663
test_dist.ppf(0.999)
664+
665+
666+
class TestRandInt:
667+
def test_gh19759(self):
668+
# test zero PMF values within the support reported by gh-19759
669+
a = -354
670+
max_range = abs(a)
671+
all_b_1 = [a + 2 ** 31 + i for i in range(max_range)]
672+
res = randint.pmf(325, a, all_b_1)
673+
assert (res > 0).all()
674+
ref = 1 / (np.asarray(all_b_1, dtype=np.float64) - a)
675+
assert_allclose(res, ref)

0 commit comments

Comments
 (0)