Skip to content

Commit db2b642

Browse files
committed
Improce test coverage for LAPACK wrappers
1 parent 63ac2d2 commit db2b642

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

src/lapack.jl

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
module LAPACK2
22

33
using Base.LinAlg: BlasInt, chkstride1, LAPACKException
4+
using Base.LinAlg.BLAS: @blasfunc
45
using Base.LinAlg.LAPACK: chkuplo
56

67
# LAPACK wrappers
78
import Base.BLAS.@blasfunc
89

910
## Standard QR/QL
10-
function steqr!(compz::Char, d::StridedVector{Float64}, e::StridedVector{Float64},
11-
Z::StridedMatrix{Float64}, work::StridedVector{Float64} = compz == 'N' ? Vector{Float64}(0) :
12-
Vector{Float64}(max(1, 2n - 2)))
11+
function steqr!(compz::Char,
12+
d::StridedVector{Float64},
13+
e::StridedVector{Float64},
14+
Z::StridedMatrix{Float64} = compz == 'N' ? Matrix{Float64}(0,0) : Matrix{Float64}(length(d), length(d)),
15+
work::StridedVector{Float64} = compz == 'N' ? Vector{Float64}(0) : Vector{Float64}(max(1, 2*length(d) - 2)))
1316

1417
# Extract sizes
1518
n = length(d)
@@ -29,7 +32,7 @@ module LAPACK2
2932
# Allocations
3033
info = Vector{BlasInt}(1)
3134

32-
ccall(("dsteqr_", Base.liblapack_name),Void,
35+
ccall((@blasfunc("dsteqr_"), Base.liblapack_name),Void,
3336
(Ptr{UInt8}, Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64},
3437
Ptr{Float64}, Ptr{BlasInt}, Ptr{Float64}, Ptr{BlasInt}),
3538
&compz, &n, d, e,
@@ -51,7 +54,7 @@ module LAPACK2
5154
# Allocations
5255
info = BlasInt[0]
5356

54-
ccall((:dsterf_, Base.liblapack_name), Void,
57+
ccall((@blasfunc("dsterf_"), Base.liblapack_name), Void,
5558
(Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64}, Ptr{BlasInt}),
5659
&n, d, e, info)
5760

@@ -61,7 +64,14 @@ module LAPACK2
6164
end
6265

6366
## Divide and Conquer
64-
function stedc!(compz::Char, d::StridedVector{Float64}, e::StridedVector{Float64}, Z::StridedMatrix{Float64}, work::StridedVector{Float64}, lwork::BlasInt, iwork::StridedVector{BlasInt}, liwork::BlasInt)
67+
function stedc!(compz::Char,
68+
d::StridedVector{Float64},
69+
e::StridedVector{Float64},
70+
Z::StridedMatrix{Float64},
71+
work::StridedVector{Float64},
72+
lwork::BlasInt,
73+
iwork::StridedVector{BlasInt},
74+
liwork::BlasInt)
6575

6676
# Extract sizes
6777
n = length(d)
@@ -75,7 +85,7 @@ module LAPACK2
7585
# Allocations
7686
info = BlasInt[0]
7787

78-
ccall((@blasfunc(:dstedc_), Base.liblapack_name), Void,
88+
ccall((@blasfunc("dstedc_"), Base.liblapack_name), Void,
7989
(Ptr{UInt8}, Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64},
8090
Ptr{Float64}, Ptr{BlasInt}, Ptr{Float64}, Ptr{BlasInt},
8191
Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
@@ -88,7 +98,10 @@ module LAPACK2
8898
return d, Z
8999
end
90100

91-
function stedc!(compz::Char, d::StridedVector{Float64}, e::StridedVector{Float64}, Z::StridedMatrix{Float64})
101+
function stedc!(compz::Char,
102+
d::StridedVector{Float64},
103+
e::StridedVector{Float64},
104+
Z::StridedMatrix{Float64} = compz == 'N' ? Matrix{Float64}(0,0) : Matrix{Float64}(length(d), length(d)))
92105

93106
work::Vector{Float64} = Float64[0]
94107
iwork::Vector{BlasInt} = BlasInt[0]
@@ -105,7 +118,22 @@ module LAPACK2
105118
## RRR
106119
for (lsymb, elty) in ((:dstemr_, :Float64), (:sstemr_, :Float32))
107120
@eval begin
108-
function stemr!(jobz::Char, range::Char, dv::StridedVector{$elty}, ev::StridedVector{$elty}, vl::$elty, vu::$elty, il::BlasInt, iu::BlasInt, w::StridedVector{$elty}, Z::StridedMatrix{$elty}, nzc::BlasInt, isuppz::StridedVector{BlasInt}, work::StridedVector{$elty}, lwork::BlasInt, iwork::StridedVector{BlasInt}, liwork::BlasInt)
121+
function stemr!(jobz::Char,
122+
range::Char,
123+
dv::StridedVector{$elty},
124+
ev::StridedVector{$elty},
125+
vl::$elty,
126+
vu::$elty,
127+
il::BlasInt,
128+
iu::BlasInt,
129+
w::StridedVector{$elty},
130+
Z::StridedMatrix{$elty},
131+
nzc::BlasInt,
132+
isuppz::StridedVector{BlasInt},
133+
work::StridedVector{$elty},
134+
lwork::BlasInt,
135+
iwork::StridedVector{BlasInt},
136+
liwork::BlasInt)
109137

110138
# Extract sizes
111139
n = length(dv)
@@ -115,7 +143,7 @@ module LAPACK2
115143
length(ev) >= n - 1 || throw(DimensionMismatch("subdiagonal is too short"))
116144

117145
# Allocations
118-
eev::Vector{$elty} = length(ev) == n - 1 ? [ev, zero($elty)] : copy(ev)
146+
eev::Vector{$elty} = length(ev) == n - 1 ? [ev; zero($elty)] : copy(ev)
119147
abstol = Vector{$elty}(1)
120148
m = Vector{BlasInt}(1)
121149
tryrac = BlasInt[1]
@@ -130,7 +158,7 @@ module LAPACK2
130158
&jobz, &range, &n, dv,
131159
eev, &vl, &vu, &il,
132160
&iu, m, w, Z,
133-
&ldz, &nzc, isuppz, tryrac,
161+
&max(1, ldz), &nzc, isuppz, tryrac,
134162
work, &lwork, iwork, &liwork,
135163
info)
136164

@@ -139,7 +167,15 @@ module LAPACK2
139167
w, Z, tryrac[1]
140168
end
141169

142-
function stemr!(jobz::Char, range::Char, dv::StridedVector{$elty}, ev::StridedVector{$elty}, vl::$elty = typemin($elty), vu::$elty = typemax($elty), il::BlasInt = 1, iu::BlasInt = length(dv))
170+
function stemr!(jobz::Char,
171+
range::Char,
172+
dv::StridedVector{$elty},
173+
ev::StridedVector{$elty},
174+
vl::$elty = typemin($elty),
175+
vu::$elty = typemax($elty),
176+
il::BlasInt = 1,
177+
iu::BlasInt = length(dv))
178+
143179
n = length(dv)
144180
w = Vector{$elty}(n)
145181
if jobz == 'N'
@@ -157,7 +193,7 @@ module LAPACK2
157193
lwork::BlasInt = -1
158194
iwork = BlasInt[0]
159195
liwork::BlasInt = -1
160-
Z = Matrix{$elty}(1, 1)
196+
Z = Matrix{$elty}(jobz == 'N' ? 1 : n, 1)
161197
nzc = -1
162198

163199
stemr!(jobz, range, dv, ev, vl, vu, il, iu, w, Z, nzc, isuppz, work, lwork, iwork, liwork)

test/qr.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Base.Test
2-
using Base.LAPACK
32
using LinearAlgebra
43
using LinearAlgebra.QRModule.qrBlocked!
54

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ using Base.Test
88
include("tridiag.jl")
99
include("svd.jl")
1010
include("rectfullpacked.jl")
11+
include("lapack.jl")
1112
# end

0 commit comments

Comments
 (0)