Skip to content

Commit 6a5e348

Browse files
authored
ENH: linalg: add batch support for remaining cholesky functions (scipy#22157)
* ENH: linalg.cho_factor: add batch support * ENH: linalg.cholesky_banded: add batch support
1 parent a755ee7 commit 6a5e348

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

scipy/linalg/_decomp_cholesky.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
105105
return c
106106

107107

108+
@_apply_over_batch(("a", 2))
108109
def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
109110
"""
110111
Compute the Cholesky decomposition of a matrix, to use in cho_solve
@@ -246,6 +247,7 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
246247
return x
247248

248249

250+
@_apply_over_batch(("ab", 2))
249251
def cholesky_banded(ab, overwrite_ab=False, lower=False, check_finite=True):
250252
"""
251253
Cholesky decompose a banded Hermitian positive-definite matrix

scipy/linalg/tests/test_batch.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def test_bandwidth(self, dtype, rng):
160160
A = np.asarray([np.triu(A, k) for k in range(-3, 3)]).reshape((2, 3, 4, 4))
161161
self.batch_test(linalg.bandwidth, A, n_out=2)
162162

163-
@pytest.mark.parametrize('fun_n_out', [(linalg.cholesky, 1), (linalg.ldl, 3)])
163+
@pytest.mark.parametrize('fun_n_out', [(linalg.cholesky, 1), (linalg.ldl, 3),
164+
(linalg.cho_factor, 2)])
164165
@pytest.mark.parametrize('dtype', floating)
165166
def test_ldl_cholesky(self, fun_n_out, dtype, rng):
166167
fun, n_out = fun_n_out
@@ -286,3 +287,10 @@ def test_rsf2cs(self, dtype, rng):
286287
A = get_random((2, 3, 4, 4), dtype=dtype, rng=rng)
287288
T, Z = linalg.schur(A)
288289
self.batch_test(linalg.rsf2csf, (T, Z), n_out=2)
290+
291+
@pytest.mark.parametrize('dtype', floating)
292+
def test_cholesky_banded(self, dtype, rng):
293+
ab = get_random((5, 4, 3, 6), dtype=dtype, rng=rng)
294+
ab[..., 0, 0] = 0
295+
ab[..., -1, :] = 10 # make diagonal dominant
296+
self.batch_test(linalg.cholesky_banded, ab)

0 commit comments

Comments
 (0)