Skip to content

Commit 2575d77

Browse files
committed
TST: linalg.lapack.sy_hetrs: consolidate tests
1 parent fb4df3e commit 2575d77

File tree

2 files changed

+12
-28
lines changed

2 files changed

+12
-28
lines changed

scipy/linalg/flapack_sym_herm.pyf.src

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ end subroutine <prefix2c>heevd_lwork
317317
integer intent(hide), depend(b) :: nrhs = shape(b, 1)
318318
integer intent(hide), depend(a) :: lda = MAX(1, shape(a, 0))
319319
integer intent(hide), depend(b,n),check(ldb >= n) :: ldb = MAX(1, shape(b, 0))
320-
integer intent(out):: info
320+
integer intent(out) :: info
321321

322322
end subroutine <prefix>sytrs
323323

@@ -337,7 +337,7 @@ end subroutine <prefix2c>heevd_lwork
337337
integer intent(hide), depend(b) :: nrhs = shape(b, 1)
338338
integer intent(hide), depend(a) :: lda = MAX(1, shape(a, 0))
339339
integer intent(hide), depend(b,n),check(ldb >= n) :: ldb = MAX(1, shape(b, 0))
340-
integer intent(out):: info
340+
integer intent(out) :: info
341341

342342
end subroutine <prefix2c>hetrs
343343

scipy/linalg/tests/test_lapack.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3415,43 +3415,27 @@ def test_tgsyl(dtype, trans, ijob):
34153415
err_msg='lhs2 and rhs2 do not match')
34163416

34173417

3418+
@pytest.mark.parametrize('mtype', ['sy', 'he']) # matrix type
34183419
@pytest.mark.parametrize('dtype', DTYPES)
34193420
@pytest.mark.parametrize('lower', (0, 1))
3420-
def test_sytrs(dtype, lower):
3421+
def test_sy_hetrs(mtype, dtype, lower):
3422+
if mtype == 'he' and dtype in REAL_DTYPES:
3423+
pytest.skip("hetrs not for real dtypes.")
34213424
rng = np.random.default_rng(1723059677121834)
34223425
n, nrhs = 20, 5
34233426
if dtype in COMPLEX_DTYPES:
34243427
A = (rng.uniform(size=(n, n)) + rng.uniform(size=(n, n))*1j).astype(dtype)
34253428
else:
34263429
A = rng.uniform(size=(n, n)).astype(dtype)
34273430

3428-
A = A + A.T
3431+
A = A + A.T if mtype == 'sy' else A + A.conj().T
34293432
b = rng.uniform(size=(n, nrhs)).astype(dtype)
3430-
sytrf, sytrf_lwork, sytrs = get_lapack_funcs(['sytrf', 'sytrf_lwork', 'sytrs'],
3431-
dtype=dtype)
3432-
lwork = sytrf_lwork(n, lower=lower)
3433-
ldu, ipiv, info = sytrf(A, lwork=lwork)
3433+
names = f'{mtype}trf', f'{mtype}trf_lwork', f'{mtype}trs'
3434+
trf, trf_lwork, trs = get_lapack_funcs(names, dtype=dtype)
3435+
lwork = trf_lwork(n, lower=lower)
3436+
ldu, ipiv, info = trf(A, lwork=lwork)
34343437
assert info == 0
3435-
x, info = sytrs(a=ldu, ipiv=ipiv, b=b)
3438+
x, info = trs(a=ldu, ipiv=ipiv, b=b)
34363439
assert info == 0
34373440
eps = np.finfo(dtype).eps
34383441
assert_allclose(A@x, b, atol=100*n*eps)
3439-
3440-
3441-
@pytest.mark.parametrize('dtype', COMPLEX_DTYPES)
3442-
@pytest.mark.parametrize('lower', (0, 1))
3443-
def test_hetrs(dtype, lower):
3444-
rng = np.random.default_rng(1723059677121834)
3445-
n, nrhs = 20, 5
3446-
A = (rng.uniform(size=(n, n)) + rng.uniform(size=(n, n))*1j).astype(dtype)
3447-
A = A + A.conj().T
3448-
b = np.random.rand(n, nrhs).astype(dtype)
3449-
hetrf, hetrf_lwork, hetrs = get_lapack_funcs(['hetrf', 'hetrf_lwork', 'hetrs'],
3450-
dtype=dtype)
3451-
lwork = hetrf_lwork(n, lower=lower)
3452-
ldu, ipiv, info = hetrf(A, lwork=lwork)
3453-
assert info == 0
3454-
x, info = hetrs(a=ldu, ipiv=ipiv, b=b)
3455-
assert info == 0
3456-
eps = np.finfo(dtype).eps
3457-
assert_allclose(A @ x, b, atol=100*n*eps)

0 commit comments

Comments
 (0)