Skip to content

Commit 1600a37

Browse files
authored
ENH: linalg: add batch support to remaining eigenvalue functions (scipy#22165)
* ENH: linalg.eigh_tridiagonal: add batch support * ENH: linalg.eigvalsh_tridiagonal: add batch support * ENH: linalg.eigvals: add batch support * TST: linalg.cdf2rdf: add test of batch support * Revert "TST: linalg.cdf2rdf: add test of batch support" This reverts commit 40bdaf3. * TST: linalg: generalize name of TestBatch
1 parent e117396 commit 1600a37

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

scipy/linalg/_decomp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,7 @@ def eig_banded(a_band, lower=False, eigvals_only=False, overwrite_a_band=False,
838838
return w, v
839839

840840

841+
@_apply_over_batch(('a', 2), ('b', 2))
841842
def eigvals(a, b=None, overwrite_a=False, check_finite=True,
842843
homogeneous_eigvals=False):
843844
"""
@@ -1126,6 +1127,7 @@ def eigvals_banded(a_band, lower=False, overwrite_a_band=False,
11261127
select_range=select_range, check_finite=check_finite)
11271128

11281129

1130+
@_apply_over_batch(('d', 1), ('e', 1))
11291131
def eigvalsh_tridiagonal(d, e, select='a', select_range=None,
11301132
check_finite=True, tol=0., lapack_driver='auto'):
11311133
"""
@@ -1207,6 +1209,7 @@ def eigvalsh_tridiagonal(d, e, select='a', select_range=None,
12071209
check_finite=check_finite, tol=tol, lapack_driver=lapack_driver)
12081210

12091211

1212+
@_apply_over_batch(('d', 1), ('e', 1))
12101213
def eigh_tridiagonal(d, e, eigvals_only=False, select='a', select_range=None,
12111214
check_finite=True, tol=0., lapack_driver='auto'):
12121215
"""

scipy/linalg/tests/test_batch.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ def get_nearly_hermitian(shape, dtype, atol, rng):
2626
return A + At + noise
2727

2828

29-
class TestOneArrayIn:
30-
# Test the functions that accept one array argument
29+
class TestBatch:
30+
# Test batch support for most linalg functions
3131

32-
def batch_test(self, fun, arrays, core_dim=2, n_out=1, kwargs=None, dtype=None):
32+
def batch_test(self, fun, arrays, *, core_dim=2, n_out=1, kwargs=None, dtype=None,
33+
broadcast=True):
3334
# Check that all outputs of batched call `fun(A, **kwargs)` are the same
3435
# as if we loop over the separate vectors/matrices in `A`. Also check
3536
# that `fun` accepts `A` by position or keyword and that results are
@@ -53,7 +54,8 @@ def batch_test(self, fun, arrays, core_dim=2, n_out=1, kwargs=None, dtype=None):
5354
res = (res2,) if n_out == 1 else res2
5455
# This is not the general behavior (only batch dimensions get
5556
# broadcasted by the decorator) but it's easier for testing.
56-
arrays = np.broadcast_arrays(*arrays)
57+
if broadcast:
58+
arrays = np.broadcast_arrays(*arrays)
5759
batch_shape = arrays[0].shape[:-core_dim]
5860
for i in range(batch_shape[0]):
5961
for j in range(batch_shape[1]):
@@ -210,7 +212,7 @@ def test_eigvals_banded(self, dtype, rng):
210212

211213
@pytest.mark.parametrize('two_in', [False, True])
212214
@pytest.mark.parametrize('fun_n_nout', [(linalg.eigh, 1), (linalg.eigh, 2),
213-
(linalg.eigvalsh, 1)])
215+
(linalg.eigvalsh, 1), (linalg.eigvals, 1)])
214216
@pytest.mark.parametrize('dtype', floating)
215217
def test_eigh(self, two_in, fun_n_nout, dtype, rng):
216218
fun, n_out = fun_n_nout
@@ -291,6 +293,15 @@ def test_rsf2cs(self, dtype, rng):
291293
@pytest.mark.parametrize('dtype', floating)
292294
def test_cholesky_banded(self, dtype, rng):
293295
ab = get_random((5, 4, 3, 6), dtype=dtype, rng=rng)
294-
ab[..., 0, 0] = 0
295296
ab[..., -1, :] = 10 # make diagonal dominant
296297
self.batch_test(linalg.cholesky_banded, ab)
298+
299+
@pytest.mark.parametrize('fun_n_out', [(linalg.eigh_tridiagonal, 2),
300+
(linalg.eigvalsh_tridiagonal, 1)])
301+
@pytest.mark.parametrize('dtype', real_floating)
302+
# "Only real arrays currently supported"
303+
def test_eigh_tridiagonal(self, fun_n_out, dtype, rng):
304+
fun, n_out = fun_n_out
305+
d = get_random((3, 4, 5), dtype=dtype, rng=rng)
306+
e = get_random((3, 4, 4), dtype=dtype, rng=rng)
307+
self.batch_test(fun, (d, e), core_dim=1, n_out=n_out, broadcast=False)

0 commit comments

Comments
 (0)