Skip to content

Commit fb4df3e

Browse files
committed
ENH: linalg: Add sy/hetrs LAPACK wrappers
1 parent 94532e7 commit fb4df3e

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

scipy/linalg/flapack_sym_herm.pyf.src

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,46 @@ end subroutine <prefix2c>heevd_lwork
302302
end subroutine <prefix>sytrf_lwork
303303

304304

305+
subroutine <prefix>sytrs(n,nrhs,a,lda,ipiv,b,ldb,info,lower)
306+
307+
! Solve A * X = B for symmetric A matrix after calling ?sytrf
308+
309+
callstatement (*f2py_func)((lower?"L":"U"),&n,&nrhs,a,&lda,ipiv,b,&ldb,&info)
310+
callprotoargument char*,F_INT*,F_INT*,<ctype>*,F_INT*,F_INT*,<ctype>*,F_INT*,F_INT*
311+
312+
<ftype> intent(in), dimension(lda, n), check((lda >= n) && (n >= 0)) :: a
313+
<ftype> intent(in,out,copy,out=x),dimension(ldb, nrhs) :: b
314+
integer optional,intent(in),check(lower==0||lower==1) :: lower = 0
315+
integer intent(in),dimension(n),depend(n) :: ipiv
316+
integer intent(hide), depend(a),check(n >= 0) :: n = shape(a, 1)
317+
integer intent(hide), depend(b) :: nrhs = shape(b, 1)
318+
integer intent(hide), depend(a) :: lda = MAX(1, shape(a, 0))
319+
integer intent(hide), depend(b,n),check(ldb >= n) :: ldb = MAX(1, shape(b, 0))
320+
integer intent(out):: info
321+
322+
end subroutine <prefix>sytrs
323+
324+
325+
subroutine <prefix2c>hetrs(n,nrhs,a,lda,ipiv,b,ldb,info,lower)
326+
327+
! Solve A * X = B for hermitian A matrix after calling ?hetrf
328+
329+
callstatement (*f2py_func)((lower?"L":"U"),&n,&nrhs,a,&lda,ipiv,b,&ldb,&info)
330+
callprotoargument char*,F_INT*,F_INT*,<ctype2c>*,F_INT*,F_INT*,<ctype2c>*,F_INT*,F_INT*
331+
332+
<ftype2c> intent(in), dimension(lda, n), check((lda >= n) && (n >= 0)) :: a
333+
<ftype2c> intent(in,out,copy,out=x),dimension(ldb, nrhs) :: b
334+
integer optional,intent(in),check(lower==0||lower==1) :: lower = 0
335+
integer intent(in),dimension(n),depend(n) :: ipiv
336+
integer intent(hide), depend(a),check(n >= 0) :: n = shape(a, 1)
337+
integer intent(hide), depend(b) :: nrhs = shape(b, 1)
338+
integer intent(hide), depend(a) :: lda = MAX(1, shape(a, 0))
339+
integer intent(hide), depend(b,n),check(ldb >= n) :: ldb = MAX(1, shape(b, 0))
340+
integer intent(out):: info
341+
342+
end subroutine <prefix2c>hetrs
343+
344+
305345
subroutine <prefix>sysv(n,nrhs,a,lda,ipiv,b,ldb,work,lwork,info,lower)
306346

307347
! Solve A * X = B for symmetric A matrix

scipy/linalg/lapack.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@
345345
chetrf_lwork
346346
zhetrf_lwork
347347
348+
chetrs
349+
zhetrs
350+
348351
chfrk
349352
zhfrk
350353
@@ -660,6 +663,11 @@
660663
csytrf_lwork
661664
zsytrf_lwork
662665
666+
ssytrs
667+
dsytrs
668+
csytrs
669+
zsytrs
670+
663671
stbtrs
664672
dtbtrs
665673
ctbtrs

scipy/linalg/tests/test_lapack.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3413,3 +3413,45 @@ def test_tgsyl(dtype, trans, ijob):
34133413
err_msg='lhs1 and rhs1 do not match')
34143414
assert_allclose(lhs2, rhs2, atol=atol, rtol=0.,
34153415
err_msg='lhs2 and rhs2 do not match')
3416+
3417+
3418+
@pytest.mark.parametrize('dtype', DTYPES)
3419+
@pytest.mark.parametrize('lower', (0, 1))
3420+
def test_sytrs(dtype, lower):
3421+
rng = np.random.default_rng(1723059677121834)
3422+
n, nrhs = 20, 5
3423+
if dtype in COMPLEX_DTYPES:
3424+
A = (rng.uniform(size=(n, n)) + rng.uniform(size=(n, n))*1j).astype(dtype)
3425+
else:
3426+
A = rng.uniform(size=(n, n)).astype(dtype)
3427+
3428+
A = A + A.T
3429+
b = rng.uniform(size=(n, nrhs)).astype(dtype)
3430+
sytrf, sytrf_lwork, sytrs = get_lapack_funcs(['sytrf', 'sytrf_lwork', 'sytrs'],
3431+
dtype=dtype)
3432+
lwork = sytrf_lwork(n, lower=lower)
3433+
ldu, ipiv, info = sytrf(A, lwork=lwork)
3434+
assert info == 0
3435+
x, info = sytrs(a=ldu, ipiv=ipiv, b=b)
3436+
assert info == 0
3437+
eps = np.finfo(dtype).eps
3438+
assert_allclose(A@x, b, atol=100*n*eps)
3439+
3440+
3441+
@pytest.mark.parametrize('dtype', COMPLEX_DTYPES)
3442+
@pytest.mark.parametrize('lower', (0, 1))
3443+
def test_hetrs(dtype, lower):
3444+
rng = np.random.default_rng(1723059677121834)
3445+
n, nrhs = 20, 5
3446+
A = (rng.uniform(size=(n, n)) + rng.uniform(size=(n, n))*1j).astype(dtype)
3447+
A = A + A.conj().T
3448+
b = np.random.rand(n, nrhs).astype(dtype)
3449+
hetrf, hetrf_lwork, hetrs = get_lapack_funcs(['hetrf', 'hetrf_lwork', 'hetrs'],
3450+
dtype=dtype)
3451+
lwork = hetrf_lwork(n, lower=lower)
3452+
ldu, ipiv, info = hetrf(A, lwork=lwork)
3453+
assert info == 0
3454+
x, info = hetrs(a=ldu, ipiv=ipiv, b=b)
3455+
assert info == 0
3456+
eps = np.finfo(dtype).eps
3457+
assert_allclose(A @ x, b, atol=100*n*eps)

0 commit comments

Comments
 (0)