Skip to content

Commit bb44a07

Browse files
authored
ENH: linalg: add batch support for functions that accept a single array (scipy#22133)
* ENH: linalg.diagsvd: add batch support * ENH: linalg.inv: add batch support * ENH: linalg.null_space: add batch support * ENH: linalg.sqrtm: add batch support * ENH: linalg.funm: add batch support * ENH: linalg.signm: add batch support * ENH: linalg.fractional_matrix_power: add batch support * ENH: linalg: add batch support to several functions * MAINT: linalg.logm: remove warnings filter from test * ENH: linalg.pinv/pinvh: add batch support * ENH: linalg.matrix_balance: add batch support * TST: linalg: test passing arrays by keyword * ENH: linalg.bandwidth: add batch support * ENH: linalg.ldl: add batch support * ENH: linalg.svd: add batch support * ENH: linalg.cholesky: add batch support * ENH: linalg.polar: add batch support * ENH: linalg: add batch support to qr, rq, schur, hessenberg * MAINT: linalg.clarkson_woodruff_transform: add batch support * ENH: linalg.orth/lu_factor: add batch support * ENH: linalg.eig_banded/eigvals_banded: add batch support * TST: linalg.bandwidth: no more error with 3D input * MAINT: linalg.cwt: remove batch support to avoid issues with sparse input
1 parent f6b23a9 commit bb44a07

14 files changed

+209
-22
lines changed

scipy/linalg/_basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from itertools import product
99
import numpy as np
1010
from numpy import atleast_1d, atleast_2d
11+
from scipy._lib._util import _apply_over_batch
1112
from .lapack import get_lapack_funcs, _compute_lwork
1213
from ._misc import LinAlgError, _datacopied, LinAlgWarning
1314
from ._decomp import _asarray_validated
@@ -1089,6 +1090,7 @@ def solve_circulant(c, b, singular='raise', tol=None,
10891090

10901091

10911092
# matrix inversion
1093+
@_apply_over_batch(('a', 2))
10921094
def inv(a, overwrite_a=False, check_finite=True):
10931095
"""
10941096
Compute the inverse of a matrix.
@@ -1499,6 +1501,7 @@ def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
14991501
lstsq.default_lapack_driver = 'gelsd'
15001502

15011503

1504+
@_apply_over_batch(('a', 2))
15021505
def pinv(a, *, atol=None, rtol=None, return_rank=False, check_finite=True):
15031506
"""
15041507
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
@@ -1622,6 +1625,7 @@ def pinv(a, *, atol=None, rtol=None, return_rank=False, check_finite=True):
16221625
return B
16231626

16241627

1628+
@_apply_over_batch(('a', 2))
16251629
def pinvh(a, atol=None, rtol=None, lower=True, return_rank=False,
16261630
check_finite=True):
16271631
"""
@@ -1715,6 +1719,7 @@ def pinvh(a, atol=None, rtol=None, lower=True, return_rank=False,
17151719
return B
17161720

17171721

1722+
@_apply_over_batch(('A', 2))
17181723
def matrix_balance(A, permute=True, scale=True, separate=False,
17191724
overwrite_a=False):
17201725
"""

scipy/linalg/_cythonized_array_utils.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ cdef inline void swap_c_and_f_layout(lapack_t *a, lapack_t *b, int r, int c) noe
9090
# ============================================================================
9191

9292

93+
@_apply_over_batch(('a', 2))
9394
@cython.embedsignature(True)
9495
def bandwidth(a):
9596
"""Return the lower and upper bandwidth of a 2D numeric array.

scipy/linalg/_decomp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,7 @@ def _check_select(select, select_range, max_ev, max_len):
657657
return select, vl, vu, il, iu, max_ev
658658

659659

660+
@_apply_over_batch(('a_band', 2))
660661
def eig_banded(a_band, lower=False, eigvals_only=False, overwrite_a_band=False,
661662
select='a', select_range=None, max_ev=0, check_finite=True):
662663
"""
@@ -1029,6 +1030,7 @@ def eigvalsh(a, b=None, *, lower=True, overwrite_a=False,
10291030
driver=driver)
10301031

10311032

1033+
@_apply_over_batch(('a_band', 2))
10321034
def eigvals_banded(a_band, lower=False, overwrite_a_band=False,
10331035
select='a', select_range=None, check_finite=True):
10341036
"""
@@ -1391,6 +1393,7 @@ def _check_info(info, driver, positive='did not converge (LAPACK info=%d)'):
13911393
raise LinAlgError(("%s " + positive) % (driver, info,))
13921394

13931395

1396+
@_apply_over_batch(('a', 2))
13941397
def hessenberg(a, calc_q=False, overwrite_a=False, check_finite=True):
13951398
"""
13961399
Compute Hessenberg form of a matrix.

scipy/linalg/_decomp_cholesky.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numpy import asarray_chkfinite, asarray, atleast_2d, empty_like
55

66
# Local imports
7+
from scipy._lib._util import _apply_over_batch
78
from ._misc import LinAlgError, _datacopied
89
from .lapack import get_lapack_funcs
910

@@ -43,6 +44,7 @@ def _cholesky(a, lower=False, overwrite_a=False, clean=True,
4344
return c, lower
4445

4546

47+
@_apply_over_batch(('a', 2))
4648
def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
4749
"""
4850
Compute the Cholesky decomposition of a matrix.

scipy/linalg/_decomp_ldl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import numpy as np
44
from numpy import (atleast_2d, arange, zeros_like, imag, diag,
55
iscomplexobj, tril, triu, argsort, empty_like)
6-
from scipy._lib._util import ComplexWarning
6+
from scipy._lib._util import ComplexWarning, _apply_over_batch
77
from ._decomp import _asarray_validated
88
from .lapack import get_lapack_funcs, _compute_lwork
99

1010
__all__ = ['ldl']
1111

1212

13+
@_apply_over_batch(('A', 2))
1314
def ldl(A, lower=True, hermitian=True, overwrite_a=False, check_finite=True):
1415
""" Computes the LDLt or Bunch-Kaufman factorization of a symmetric/
1516
hermitian matrix.

scipy/linalg/_decomp_lu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
from itertools import product
88

9+
from scipy._lib._util import _apply_over_batch
10+
911
# Local imports
1012
from ._misc import _datacopied, LinAlgWarning
1113
from .lapack import get_lapack_funcs
@@ -17,6 +19,7 @@
1719
__all__ = ['lu', 'lu_solve', 'lu_factor']
1820

1921

22+
@_apply_over_batch(('a', 2))
2023
def lu_factor(a, overwrite_a=False, check_finite=True):
2124
"""
2225
Compute pivoted LU decomposition of a matrix.

scipy/linalg/_decomp_polar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import numpy as np
2+
from scipy._lib._util import _apply_over_batch
23
from scipy.linalg import svd
34

45

56
__all__ = ['polar']
67

78

9+
@_apply_over_batch(('a', 2))
810
def polar(a, side="right"):
911
"""
1012
Compute the polar decomposition.

scipy/linalg/_decomp_qr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""QR decomposition functions."""
22
import numpy as np
33

4+
from scipy._lib._util import _apply_over_batch
5+
46
# Local imports
57
from .lapack import get_lapack_funcs
68
from ._misc import _datacopied
@@ -23,6 +25,7 @@ def safecall(f, name, *args, **kwargs):
2325
return ret[:-2]
2426

2527

28+
@_apply_over_batch(('a', 2))
2629
def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
2730
check_finite=True):
2831
"""
@@ -366,6 +369,7 @@ def qr_multiply(a, c, mode='right', pivoting=False, conjugate=False,
366369
return (cQ,) + raw[1:]
367370

368371

372+
@_apply_over_batch(('a', 2))
369373
def rq(a, overwrite_a=False, lwork=None, mode='full', check_finite=True):
370374
"""
371375
Compute RQ decomposition of a matrix.

scipy/linalg/_decomp_schur.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numpy import asarray_chkfinite, single, asarray, array
44
from numpy.linalg import norm
55

6-
6+
from scipy._lib._util import _apply_over_batch
77
# Local imports.
88
from ._misc import LinAlgError, _datacopied
99
from .lapack import get_lapack_funcs
@@ -14,6 +14,7 @@
1414
_double_precision = ['i', 'l', 'd']
1515

1616

17+
@_apply_over_batch(('a', 2))
1718
def schur(a, output='real', lwork=None, overwrite_a=False, sort=None,
1819
check_finite=True):
1920
"""

scipy/linalg/_decomp_svd.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22
import numpy as np
33
from numpy import zeros, r_, diag, dot, arccos, arcsin, where, clip
44

5+
from scipy._lib._util import _apply_over_batch
6+
57
# Local imports.
68
from ._misc import LinAlgError, _datacopied
79
from .lapack import get_lapack_funcs, _compute_lwork
810
from ._decomp import _asarray_validated
911

12+
1013
__all__ = ['svd', 'svdvals', 'diagsvd', 'orth', 'subspace_angles', 'null_space']
1114

1215

16+
@_apply_over_batch(('a', 2))
1317
def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False,
1418
check_finite=True, lapack_driver='gesdd'):
1519
"""
@@ -249,6 +253,7 @@ def svdvals(a, overwrite_a=False, check_finite=True):
249253
check_finite=check_finite)
250254

251255

256+
@_apply_over_batch(('s', 1))
252257
def diagsvd(s, M, N):
253258
"""
254259
Construct the sigma matrix in SVD from singular values and size M, N.
@@ -301,6 +306,7 @@ def diagsvd(s, M, N):
301306

302307
# Orthonormal decomposition
303308

309+
@_apply_over_batch(('A', 2))
304310
def orth(A, rcond=None):
305311
"""
306312
Construct an orthonormal basis for the range of A using SVD
@@ -349,6 +355,7 @@ def orth(A, rcond=None):
349355
return Q
350356

351357

358+
@_apply_over_batch(('A', 2))
352359
def null_space(A, rcond=None, *, overwrite_a=False, check_finite=True,
353360
lapack_driver='gesdd'):
354361
"""

0 commit comments

Comments
 (0)