Skip to content

Commit e003d8c

Browse files
authored
ENH: linalg: add Python wrapper of ?gtcon (scipy#21328)
* ENH: linalg: add Python wrapper of ?gtcon * MAINT: linalg: wrap both signatures of gtcon; fix test * TST: linalg: accommodate old versions of NumPy
1 parent 2a15022 commit e003d8c

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

scipy/linalg/flapack_gen_tri.pyf.src

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,57 @@ subroutine <prefix>gttrs(trans, n, nrhs, dl, d, du, du2, ipiv, b, ldb, info)
6666
end subroutine gttrs
6767

6868

69+
subroutine <prefix2>gtcon(norm,n,dl,d,du,du2,ipiv,anorm,rcond,work,iwork,info)
70+
! ?GTCON estimates the reciprocal of the condition number of a real
71+
! tridiagonal matrix A using the LU factorization as computed by
72+
! ?GTTRF.
73+
! An estimate is obtained for norm(inv(A)), and the reciprocal of the
74+
! condition number is computed as RCOND = 1 / (ANORM * norm(inv(A))).
75+
threadsafe
76+
callstatement (*f2py_func)(norm,&n,dl,d,du,du2,ipiv,&anorm,&rcond,work,iwork,&info)
77+
callprotoargument char*, F_INT*, <ctype2>*, <ctype2>*, <ctype2>*, <ctype2>*, F_INT*, <ctype2>*, <ctype2>*, <ctype2>*, F_INT*, F_INT*
78+
79+
character optional, intent(in) :: norm = '1'
80+
integer intent(hide), depend(d) :: n = max(3, len(d))
81+
<ftype2> intent(in), depend(n), dimension(n - 1) :: dl
82+
<ftype2> intent(in), dimension(n) :: d
83+
<ftype2> intent(in), depend(n), dimension(n - 1) :: du
84+
<ftype2> intent(in), depend(n), dimension(n - 2) :: du2
85+
integer intent(in), depend(n), dimension(n) :: ipiv
86+
<ftype2> intent(in) :: anorm
87+
<ftype2> intent(out) :: rcond
88+
<ftype2> intent(hide, cache), dimension(2*n), depend(n) :: work
89+
integer intent(hide, cache), dimension(n), depend(n) :: iwork
90+
integer intent(out) :: info
91+
92+
end subroutine <prefix2>gtcon
93+
94+
95+
subroutine <prefix2c>gtcon(norm,n,dl,d,du,du2,ipiv,anorm,rcond,work,info)
96+
! ?GTCON estimates the reciprocal of the condition number of a real
97+
! tridiagonal matrix A using the LU factorization as computed by
98+
! ?GTTRF.
99+
! An estimate is obtained for norm(inv(A)), and the reciprocal of the
100+
! condition number is computed as RCOND = 1 / (ANORM * norm(inv(A))).
101+
threadsafe
102+
callstatement (*f2py_func)(norm,&n,dl,d,du,du2,ipiv,&anorm,&rcond,work,&info)
103+
callprotoargument char*, F_INT*, <ctype2c>*, <ctype2c>*, <ctype2c>*, <ctype2c>*, F_INT*, <ctype2>*, <ctype2>*, <ctype2c>*, F_INT*
104+
105+
character optional, intent(in) :: norm = '1'
106+
integer intent(hide), depend(d) :: n = max(3, len(d))
107+
<ftype2c> intent(in), depend(n), dimension(n - 1) :: dl
108+
<ftype2c> intent(in), dimension(n) :: d
109+
<ftype2c> intent(in), depend(n), dimension(n - 1) :: du
110+
<ftype2c> intent(in), depend(n), dimension(n - 2) :: du2
111+
integer intent(in), depend(n), dimension(n) :: ipiv
112+
<ftype2> intent(in) :: anorm
113+
<ftype2> intent(out) :: rcond
114+
<ftype2c> intent(hide, cache), dimension(2*n), depend(n) :: work
115+
integer intent(out) :: info
116+
117+
end subroutine <prefix2c>gtcon
118+
119+
69120
subroutine <prefix2>gtsvx(fact,trans,n,nrhs,dl,d,du,dlf,df,duf,du2,ipiv,b,ldb,x,ldx,rcond,ferr,berr,work,iwork,info)
70121
! ?GTSVX uses the LU factorization to compute the solution to a real
71122
! system of linear equations A * X = B or A**T * X = B,

scipy/linalg/lapack.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,11 @@
793793
cgttrs
794794
zgttrs
795795
796+
sgtcon
797+
dgtcon
798+
cgtcon
799+
zgtcon
800+
796801
stpqrt
797802
dtpqrt
798803
ctpqrt

scipy/linalg/tests/test_lapack.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,6 +2150,34 @@ def test_gttrf_gttrs_NAG_f07cdf_f07cef_f07crf_f07csf(du, d, dl, du_exp, d_exp,
21502150
assert_allclose(x_gttrs, x)
21512151

21522152

2153+
@pytest.mark.parametrize('dtype', DTYPES)
2154+
@pytest.mark.parametrize('norm', ['1', 'I', 'O'])
2155+
@pytest.mark.parametrize('n', [3, 10])
2156+
def test_gtcon(dtype, norm, n):
2157+
rng = np.random.default_rng(23498324)
2158+
2159+
d = rng.random(n) + rng.random(n)*1j
2160+
dl = rng.random(n - 1) + rng.random(n - 1)*1j
2161+
du = rng.random(n - 1) + rng.random(n - 1)*1j
2162+
A = np.diag(d) + np.diag(dl, -1) + np.diag(du, 1)
2163+
if np.issubdtype(dtype, np.floating):
2164+
A, d, dl, du = A.real, d.real, dl.real, du.real
2165+
A, d, dl, du = A.astype(dtype), d.astype(dtype), dl.astype(dtype), du.astype(dtype)
2166+
2167+
anorm = np.abs(A).sum(axis=0).max()
2168+
2169+
gttrf, gtcon = get_lapack_funcs(('gttrf', 'gtcon'), (A,))
2170+
dl, d, du, du2, ipiv, info = gttrf(dl, d, du)
2171+
res, _ = gtcon(dl, d, du, du2, ipiv, anorm, norm=norm)
2172+
2173+
gecon, getrf = get_lapack_funcs(('gecon', 'getrf'), (A,))
2174+
lu, ipvt, info = getrf(A)
2175+
ref, _ = gecon(lu, anorm, norm=norm)
2176+
2177+
rtol = np.finfo(dtype).eps**0.75
2178+
assert_allclose(res, ref, rtol=rtol)
2179+
2180+
21532181
@pytest.mark.parametrize('dtype', DTYPES)
21542182
@pytest.mark.parametrize('shape', [(3, 7), (7, 3), (2**18, 2**18)])
21552183
def test_geqrfp_lwork(dtype, shape):

0 commit comments

Comments
 (0)