Skip to content

Commit f765df3

Browse files
authored
Merge pull request #117 from JuliaLinearAlgebra/an/morecoverage
Fix and test pteqr
2 parents cac315b + d4f43ae commit f765df3

File tree

3 files changed

+46
-14
lines changed

3 files changed

+46
-14
lines changed

src/lapack.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -425,31 +425,39 @@ function lahqr!(H::StridedMatrix{Float64})
425425
end
426426

427427
## Cholesky + singular values
428-
function spteqr!(
428+
function pteqr!(
429429
compz::Char,
430430
d::StridedVector{Float64},
431431
e::StridedVector{Float64},
432432
Z::StridedMatrix{Float64},
433-
work::StridedVector{Float64} = Vector{Float64}(undef, 4length(d)),
433+
work::StridedVector{Float64},
434434
)
435-
436435
n = length(d)
437436
ldz = stride(Z, 2)
438437

439438
# Checks
440-
length(e) >= n - 1 || throw(DimensionMismatch("subdiagonal vector is too short"))
439+
if length(e) < n - 1
440+
throw(DimensionMismatch("subdiagonal vector is too short"))
441+
end
442+
chkstride1(d)
443+
chkstride1(e)
444+
chkstride1(Z)
441445
if compz == 'N'
442-
elseif compz == in('V', 'I')
443-
size(Z, 1) >= n || throw(DimensionMismatch("Z does not have enough rows"))
444-
size(Z, 2) >= ldz || throw(DimensionMismatch("Z does not have enough columns"))
446+
elseif compz ('V', 'I')
447+
if size(Z, 1) < n
448+
throw(DimensionMismatch("Z does not have enough rows"))
449+
end
450+
if size(Z, 2) < ldz
451+
throw(DimensionMismatch("Z does not have enough columns"))
452+
end
445453
else
446454
throw(ArgumentError("compz must be either 'N', 'V', or 'I'"))
447455
end
448456

449-
info = Vector{BlasInt}(undef, 1)
457+
info = Ref{BlasInt}(1)
450458

451459
ccall(
452-
(@blasfunc(:dpteqr_), liblapack_name),
460+
(@blasfunc(dpteqr_), liblapack_name),
453461
Cvoid,
454462
(
455463
Ref{UInt8},
@@ -459,7 +467,7 @@ function spteqr!(
459467
Ptr{Float64},
460468
Ref{BlasInt},
461469
Ptr{Float64},
462-
Ptr{BlasInt},
470+
Ref{BlasInt},
463471
),
464472
compz,
465473
n,
@@ -471,9 +479,23 @@ function spteqr!(
471479
info,
472480
)
473481

474-
info[1] == 0 || throw(LAPACKException(info[1]))
482+
if info[] != 0
483+
throw(LAPACKException(info[]))
484+
end
485+
486+
return d, Z
487+
end
475488

476-
d, Z
489+
function pteqr!(compz::Char, d::StridedVector{Float64}, e::StridedVector{Float64})
490+
n = length(d)
491+
492+
Z = if compz == 'I'
493+
Matrix{Float64}(undef, n, n)
494+
else
495+
Matrix{Float64}(undef, 0, 0)
496+
end
497+
work = Vector{Float64}(undef, 4 * n)
498+
return pteqr!(compz, d, e, Z, work)
477499
end
478500

479501
# Gu's dnc eigensolver

test/lapack.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ using GenericLinearAlgebra.LAPACK2
9191
(Float32, Float64),
9292
side in ('L', 'R', 'B'),
9393
howmny in ('A', 'S')
94-
#='B', =#
94+
9595
select = ones(Int, n)
9696
S, P = triu(randn(eltype, n, n)), triu(randn(eltype, n, n))
9797
VL, VR, m = LAPACK2.tgevc!(side, howmny, select, copy(S), copy(P))
@@ -106,4 +106,14 @@ using GenericLinearAlgebra.LAPACK2
106106
sqrt(eps(eltype))
107107
end
108108
end
109+
110+
@testset "pteqr" begin
111+
d = fill(10.0, n)
112+
e = fill(1.0, n - 1)
113+
vals, vecs = LAPACK2.pteqr!('I', copy(d), copy(e))
114+
@test SymTridiagonal(d, e) vecs * Diagonal(vals) * vecs'
115+
116+
vals2, _ = LAPACK2.pteqr!('N', copy(d), copy(e))
117+
@test vals vals2
118+
end
109119
end

test/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,6 @@ using Test, GenericLinearAlgebra, LinearAlgebra, Quaternions, DoubleFloats
239239
@testset "Generic HessenbergQ multiplication" begin
240240
A = big.(randn(10, 10))
241241
BF = GenericLinearAlgebra.bidiagonalize!(copy(A))
242-
@test (BF.rightQ'*Matrix(I, size(A)...))*BF.rightQ I
242+
@test (BF.rightQ' * Matrix(I, size(A)...)) * BF.rightQ I
243243
end
244244
end

0 commit comments

Comments
 (0)