Skip to content

Commit fe518f1

Browse files
authored
ENH: linalg: add batch support for matrix -> scalar funcs (scipy#22127)
* ENH: linalg.expm_cond: add batch support * DOC: linalg: inject batch support note automatically * ENH: linalg.issymmetric/ishermitian: add batch support * MAINT: linalg: adjustments to address CI failures
1 parent 92ed633 commit fe518f1

File tree

6 files changed

+75
-5
lines changed

6 files changed

+75
-5
lines changed

scipy/_lib/_util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,14 @@ def _dict_formatter(d, n=0, mplus=1, sorter=None):
11761176
return s
11771177

11781178

1179+
_batch_note = """
1180+
The documentation is written assuming array arguments are of specified
1181+
"core" shapes. However, array argument(s) of this function may have additional
1182+
"batch" dimensions prepended to the core shape. In this case, the array is treated
1183+
as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
1184+
"""
1185+
1186+
11791187
def _apply_over_batch(*argdefs):
11801188
"""
11811189
Factory for decorator that applies a function over batched arguments.
@@ -1266,5 +1274,9 @@ def wrapper(*args, **kwargs):
12661274
# contributor to pass an `pack_result` callable to the decorator factory.
12671275
return results[0] if len(results) == 1 else results
12681276

1277+
doc = FunctionDoc(wrapper)
1278+
doc['Extended Summary'].append(_batch_note.rstrip())
1279+
wrapper.__doc__ = str(doc).split("\n", 1)[1] # remove signature
1280+
12691281
return wrapper
12701282
return decorator

scipy/linalg/_cythonized_array_utils.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from scipy.linalg._cythonized_array_utils cimport (
88
)
99
from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, cgetrf, zgetrf
1010
from libc.stdlib cimport malloc, free
11+
from scipy._lib._util import _apply_over_batch
1112

1213
__all__ = ['bandwidth', 'issymmetric', 'ishermitian']
1314

@@ -237,6 +238,7 @@ cdef inline (int, int) band_check_internal_noncontig(const np_numeric_t[:, :]A)
237238
return lower_band, upper_band
238239

239240

241+
@_apply_over_batch(('a', 2))
240242
@cython.embedsignature(True)
241243
def issymmetric(a, atol=None, rtol=None):
242244
"""Check if a square 2D array is symmetric.
@@ -367,6 +369,7 @@ cdef inline bint is_sym_her_real_noncontig_internal(const np_numeric_t[:, :]A) n
367369
return True
368370

369371

372+
@_apply_over_batch(('a', 2))
370373
@cython.embedsignature(True)
371374
def ishermitian(a, atol=None, rtol=None):
372375
"""Check if a square 2D array is Hermitian.

scipy/linalg/_decomp.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,6 @@ def eig(a, b=None, left=False, right=True, overwrite_a=False,
124124
125125
where ``.H`` is the Hermitian conjugation.
126126
127-
The documentation is written assuming array arguments are of specified
128-
"core" shapes. However, array argument(s) of this function may have additional
129-
"batch" dimensions prepended to the core shape. In this case, the array is treated
130-
as a batch of lower-dimensional slices; see :ref:`linalg_batch` for details.
131-
132127
Parameters
133128
----------
134129
a : (M, M) array_like

scipy/linalg/_expm_frechet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Frechet derivative of the matrix exponential."""
22
import numpy as np
33
import scipy.linalg
4+
from scipy._lib._util import _apply_over_batch
5+
46

57
__all__ = ['expm_frechet', 'expm_cond']
68

@@ -351,6 +353,7 @@ def expm_frechet_kronform(A, method=None, check_finite=True):
351353
return np.vstack(cols).T
352354

353355

356+
@_apply_over_batch(('A', 2))
354357
def expm_cond(A, check_finite=True):
355358
"""
356359
Relative condition number of the matrix exponential in the Frobenius norm.

scipy/linalg/tests/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
python_sources = [
22
'__init__.py',
33
'test_basic.py',
4+
'test_batch.py',
45
'test_blas.py',
56
'test_cython_blas.py',
67
'test_cython_lapack.py',

scipy/linalg/tests/test_batch.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
import numpy as np
3+
from numpy.testing import assert_allclose
4+
from scipy import linalg
5+
6+
7+
real_floating = [np.float32, np.float64]
8+
complex_floating = [np.complex64, np.complex128]
9+
floating = real_floating + complex_floating
10+
11+
12+
def get_nearly_hermitian(shape, dtype, atol, rng):
13+
# Generate a batch of nearly Hermitian matrices with specified
14+
# `shape` and `dtype`. `atol` controls the level of noise in
15+
# Hermitian-ness to by generated by `rng`.
16+
A = rng.random(shape).astype(dtype)
17+
At = np.conj(A.swapaxes(-1, -2))
18+
noise = rng.standard_normal(size=A.shape).astype(dtype) * atol
19+
return A + At + noise
20+
21+
22+
class TestMatrixInScalarOut:
23+
24+
def batch_test(self, fun, args=(), kwargs=None, dtype=np.float64,
25+
batch_shape=(5, 3), core_shape=(4, 4), seed=8342310302941288912051):
26+
kwargs = {} if kwargs is None else kwargs
27+
rng = np.random.default_rng(seed)
28+
# test_expm_cond doesn't need symmetric/hermitian matrices, and
29+
# test_issymmetric doesn't need hermitian matrices, but it doesn't hurt.
30+
A = get_nearly_hermitian(batch_shape + core_shape, dtype, 3e-4, rng)
31+
32+
res = fun(A, *args, **kwargs)
33+
34+
for i in range(batch_shape[0]):
35+
for j in range(batch_shape[1]):
36+
ref = fun(A[i, j], *args, **kwargs)
37+
assert_allclose(res[i, j], ref)
38+
39+
return res
40+
41+
@pytest.mark.parametrize('dtype', floating)
42+
def test_expm_cond(self, dtype):
43+
self.batch_test(linalg.expm_cond, dtype=dtype)
44+
45+
@pytest.mark.parametrize('dtype', floating)
46+
def test_issymmetric(self, dtype):
47+
res = self.batch_test(linalg.issymmetric, dtype=dtype, kwargs=dict(atol=1e-3))
48+
assert not np.all(res) # ensure test is not trivial: not all True or False;
49+
assert np.any(res) # also confirms that `atol` is passed to issymmetric
50+
51+
@pytest.mark.parametrize('dtype', floating)
52+
def test_ishermitian(self, dtype):
53+
res = self.batch_test(linalg.issymmetric, dtype=dtype, kwargs=dict(atol=1e-3))
54+
assert not np.all(res) # ensure test is not trivial: not all True or False;
55+
assert np.any(res) # also confirms that `atol` is passed to ishermitian
56+

0 commit comments

Comments
 (0)