Skip to content

Commit d2545a6

Browse files
authored
Merge pull request scipy#21336 from ilayn/sy_hetrs_wrapper
ENH: linalg: Add `sy/hetrs` LAPACK wrappers
2 parents 1d50e6e + 2575d77 commit d2545a6

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

scipy/linalg/flapack_sym_herm.pyf.src

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,46 @@ end subroutine <prefix2c>heevd_lwork
302302
end subroutine <prefix>sytrf_lwork
303303

304304

305+
subroutine <prefix>sytrs(n,nrhs,a,lda,ipiv,b,ldb,info,lower)
306+
307+
! Solve A * X = B for symmetric A matrix after calling ?sytrf
308+
309+
callstatement (*f2py_func)((lower?"L":"U"),&n,&nrhs,a,&lda,ipiv,b,&ldb,&info)
310+
callprotoargument char*,F_INT*,F_INT*,<ctype>*,F_INT*,F_INT*,<ctype>*,F_INT*,F_INT*
311+
312+
<ftype> intent(in), dimension(lda, n), check((lda >= n) && (n >= 0)) :: a
313+
<ftype> intent(in,out,copy,out=x),dimension(ldb, nrhs) :: b
314+
integer optional,intent(in),check(lower==0||lower==1) :: lower = 0
315+
integer intent(in),dimension(n),depend(n) :: ipiv
316+
integer intent(hide), depend(a),check(n >= 0) :: n = shape(a, 1)
317+
integer intent(hide), depend(b) :: nrhs = shape(b, 1)
318+
integer intent(hide), depend(a) :: lda = MAX(1, shape(a, 0))
319+
integer intent(hide), depend(b,n),check(ldb >= n) :: ldb = MAX(1, shape(b, 0))
320+
integer intent(out) :: info
321+
322+
end subroutine <prefix>sytrs
323+
324+
325+
subroutine <prefix2c>hetrs(n,nrhs,a,lda,ipiv,b,ldb,info,lower)
326+
327+
! Solve A * X = B for hermitian A matrix after calling ?hetrf
328+
329+
callstatement (*f2py_func)((lower?"L":"U"),&n,&nrhs,a,&lda,ipiv,b,&ldb,&info)
330+
callprotoargument char*,F_INT*,F_INT*,<ctype2c>*,F_INT*,F_INT*,<ctype2c>*,F_INT*,F_INT*
331+
332+
<ftype2c> intent(in), dimension(lda, n), check((lda >= n) && (n >= 0)) :: a
333+
<ftype2c> intent(in,out,copy,out=x),dimension(ldb, nrhs) :: b
334+
integer optional,intent(in),check(lower==0||lower==1) :: lower = 0
335+
integer intent(in),dimension(n),depend(n) :: ipiv
336+
integer intent(hide), depend(a),check(n >= 0) :: n = shape(a, 1)
337+
integer intent(hide), depend(b) :: nrhs = shape(b, 1)
338+
integer intent(hide), depend(a) :: lda = MAX(1, shape(a, 0))
339+
integer intent(hide), depend(b,n),check(ldb >= n) :: ldb = MAX(1, shape(b, 0))
340+
integer intent(out) :: info
341+
342+
end subroutine <prefix2c>hetrs
343+
344+
305345
subroutine <prefix>sysv(n,nrhs,a,lda,ipiv,b,ldb,work,lwork,info,lower)
306346

307347
! Solve A * X = B for symmetric A matrix

scipy/linalg/lapack.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@
345345
chetrf_lwork
346346
zhetrf_lwork
347347
348+
chetrs
349+
zhetrs
350+
348351
chfrk
349352
zhfrk
350353
@@ -660,6 +663,11 @@
660663
csytrf_lwork
661664
zsytrf_lwork
662665
666+
ssytrs
667+
dsytrs
668+
csytrs
669+
zsytrs
670+
663671
stbtrs
664672
dtbtrs
665673
ctbtrs

scipy/linalg/tests/test_lapack.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3413,3 +3413,29 @@ def test_tgsyl(dtype, trans, ijob):
34133413
err_msg='lhs1 and rhs1 do not match')
34143414
assert_allclose(lhs2, rhs2, atol=atol, rtol=0.,
34153415
err_msg='lhs2 and rhs2 do not match')
3416+
3417+
3418+
@pytest.mark.parametrize('mtype', ['sy', 'he']) # matrix type
3419+
@pytest.mark.parametrize('dtype', DTYPES)
3420+
@pytest.mark.parametrize('lower', (0, 1))
3421+
def test_sy_hetrs(mtype, dtype, lower):
3422+
if mtype == 'he' and dtype in REAL_DTYPES:
3423+
pytest.skip("hetrs not for real dtypes.")
3424+
rng = np.random.default_rng(1723059677121834)
3425+
n, nrhs = 20, 5
3426+
if dtype in COMPLEX_DTYPES:
3427+
A = (rng.uniform(size=(n, n)) + rng.uniform(size=(n, n))*1j).astype(dtype)
3428+
else:
3429+
A = rng.uniform(size=(n, n)).astype(dtype)
3430+
3431+
A = A + A.T if mtype == 'sy' else A + A.conj().T
3432+
b = rng.uniform(size=(n, nrhs)).astype(dtype)
3433+
names = f'{mtype}trf', f'{mtype}trf_lwork', f'{mtype}trs'
3434+
trf, trf_lwork, trs = get_lapack_funcs(names, dtype=dtype)
3435+
lwork = trf_lwork(n, lower=lower)
3436+
ldu, ipiv, info = trf(A, lwork=lwork)
3437+
assert info == 0
3438+
x, info = trs(a=ldu, ipiv=ipiv, b=b)
3439+
assert info == 0
3440+
eps = np.finfo(dtype).eps
3441+
assert_allclose(A@x, b, atol=100*n*eps)

0 commit comments

Comments
 (0)