Skip to content

Commit 20a7736

Browse files
authored
Merge pull request #113 from JuliaLinearAlgebra/an/testlapack
Improve some LAPACK test coverage
2 parents ef815d6 + 0c3d3e1 commit 20a7736

File tree

2 files changed

+64
-33
lines changed

2 files changed

+64
-33
lines changed

src/lapack.jl

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -483,12 +483,12 @@ for (f, elty) in ((:dsyevd_, :Float64), (:ssyevd_, :Float32))
483483
function syevd!(jobz::Char, uplo::Char, A::StridedMatrix{$elty})
484484
n = LinearAlgebra.checksquare(A)
485485
lda = stride(A, 2)
486-
w = Vector{$elty}(n)
486+
w = Vector{$elty}(undef, n)
487487
work = Vector{$elty}(undef, 1)
488488
lwork = BlasInt(-1)
489489
iwork = Vector{BlasInt}(undef, 1)
490490
liwork = BlasInt(-1)
491-
info = BlasInt[0]
491+
info = Ref(BlasInt(0))
492492
for i = 1:2
493493
ccall(
494494
(@blasfunc($f), liblapack_name),
@@ -504,7 +504,7 @@ for (f, elty) in ((:dsyevd_, :Float64), (:ssyevd_, :Float32))
504504
Ref{BlasInt},
505505
Ptr{BlasInt},
506506
Ref{BlasInt},
507-
Ptr{BlasInt},
507+
Ref{BlasInt},
508508
),
509509
jobz,
510510
uplo,
@@ -519,8 +519,8 @@ for (f, elty) in ((:dsyevd_, :Float64), (:ssyevd_, :Float32))
519519
info,
520520
)
521521

522-
if info[1] != 0
523-
return LinearAlgebra.LAPACKException(info[1])
522+
if info[] != 0
523+
return LinearAlgebra.LAPACKException(info[])
524524
end
525525
if i == 1
526526
lwork = BlasInt(work[1])
@@ -540,14 +540,14 @@ for (f, elty, relty) in
540540
function heevd!(jobz::Char, uplo::Char, A::StridedMatrix{$elty})
541541
n = LinearAlgebra.checksquare(A)
542542
lda = stride(A, 2)
543-
w = Vector{$relty}(n)
543+
w = Vector{$relty}(undef, n)
544544
work = Vector{$elty}(undef, 1)
545545
lwork = BlasInt(-1)
546546
rwork = Vector{$relty}(undef, 1)
547547
lrwork = BlasInt(-1)
548548
iwork = Vector{BlasInt}(undef, 1)
549549
liwork = BlasInt(-1)
550-
info = BlasInt[0]
550+
info = Ref(BlasInt(0))
551551
for i = 1:2
552552
ccall(
553553
(@blasfunc($f), liblapack_name),
@@ -565,7 +565,7 @@ for (f, elty, relty) in
565565
Ref{BlasInt},
566566
Ptr{BlasInt},
567567
Ref{BlasInt},
568-
Ptr{BlasInt},
568+
Ref{BlasInt},
569569
),
570570
jobz,
571571
uplo,
@@ -582,8 +582,8 @@ for (f, elty, relty) in
582582
info,
583583
)
584584

585-
if info[1] != 0
586-
return LinearAlgebra.LAPACKException(info[1])
585+
if info[] != 0
586+
return LinearAlgebra.LAPACKException(info[])
587587
end
588588

589589
if i == 1
@@ -675,23 +675,23 @@ for (f, elty) in ((:dtgevc_, :Float64), (:stgevc_, :Float32))
675675
)
676676
end
677677
elseif howmny == 'S'
678-
mx, mn = extrama(select)
679-
if mx > 1 || nm < 0
678+
mx, mn = extrema(select)
679+
if mx > 1 || mn < 0
680680
throw(ArgumentError("the elements of select must be either 0 or 1"))
681681
end
682-
if sum(howmny) != mm
682+
if sum(select) != mm
683683
throw(
684684
DimensionMismatch(
685-
"the number of columns in the output arrays is $mm, but you have selected $(sum(howmny)) vectors",
685+
"the number of columns in the output arrays is $mm, but you have selected $(sum(select)) vectors",
686686
),
687687
)
688688
end
689689
else
690-
throw(ArgumentError("howmny must be either A, B, or S"))
690+
throw(ArgumentError("howmny must be either A, B, or S but was $howmny"))
691691
end
692692

693-
m = BlasInt[0]
694-
info = BlasInt[0]
693+
m = Ref(BlasInt(0))
694+
info = Ref(BlasInt(0))
695695

696696
ccall(
697697
(@blasfunc($f), liblapack_name),
@@ -710,9 +710,9 @@ for (f, elty) in ((:dtgevc_, :Float64), (:stgevc_, :Float32))
710710
Ptr{$elty},
711711
Ref{BlasInt},
712712
Ref{BlasInt},
713-
Ptr{BlasInt},
713+
Ref{BlasInt},
714714
Ptr{$elty},
715-
Ptr{BlasInt},
715+
Ref{BlasInt},
716716
),
717717
side,
718718
howmny,
@@ -732,11 +732,11 @@ for (f, elty) in ((:dtgevc_, :Float64), (:stgevc_, :Float32))
732732
info,
733733
)
734734

735-
if info[1] != 0
736-
throw(LAPACKException(info[1]))
735+
if info[] != 0
736+
throw(LAPACKException(info[]))
737737
end
738738

739-
return VL, VR, m[1]
739+
return VL, VR, m[]
740740
end
741741

742742
function tgevc!(
@@ -750,7 +750,7 @@ for (f, elty) in ((:dtgevc_, :Float64), (:stgevc_, :Float32))
750750
)
751751

752752
n = LinearAlgebra.checksquare(S)
753-
work = Vector{$elty}(6n)
753+
work = Vector{$elty}(undef, 6n)
754754

755755
return tgevc!(side, howmny, select, S, P, VL, VR, work)
756756
end
@@ -765,26 +765,26 @@ for (f, elty) in ((:dtgevc_, :Float64), (:stgevc_, :Float32))
765765
# No checks here as they are done in method above
766766
n = LinearAlgebra.checksquare(S)
767767
if side == 'L'
768-
VR = Matrix{$elty}(n, 0)
768+
VR = Matrix{$elty}(undef, n, 0)
769769
if howmny == 'A' || howmny == 'B'
770-
VL = Matrix{$elty}(n, n)
770+
VL = Matrix{$elty}(undef, n, n)
771771
else
772-
VL = Matrix{$elty}(n, sum(select))
772+
VL = Matrix{$elty}(undef, n, sum(select))
773773
end
774774
elseif side == 'R'
775-
VL = Matrix{$elty}(n, 0)
775+
VL = Matrix{$elty}(undef, n, 0)
776776
if howmny == 'A' || howmny == 'B'
777-
VR = Matrix{$elty}(n, n)
777+
VR = Matrix{$elty}(undef, n, n)
778778
else
779-
VR = Matrix{$elty}(n, sum(select))
779+
VR = Matrix{$elty}(undef, n, sum(select))
780780
end
781781
else
782782
if howmny == 'A' || howmny == 'B'
783-
VL = Matrix{$elty}(n, n)
784-
VR = Matrix{$elty}(n, n)
783+
VL = Matrix{$elty}(undef, n, n)
784+
VR = Matrix{$elty}(undef, n, n)
785785
else
786-
VL = Matrix{$elty}(n, sum(select))
787-
VR = Matrix{$elty}(n, sum(select))
786+
VL = Matrix{$elty}(undef, n, sum(select))
787+
VR = Matrix{$elty}(undef, n, sum(select))
788788
end
789789
end
790790

test/lapack.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,35 @@ using GenericLinearAlgebra.LAPACK2
6868
# LAPACK's multishift algorithm (the default) seems to be broken
6969
@test !(_vals sort(eigvals(T)))
7070
end
71+
72+
@testset "syevd: eltype=$eltype, uplo=$uplo" for eltype in (Float32, Float64, ComplexF32, ComplexF64), uplo in ('U', 'L')
73+
A = randn(eltype, n, n)
74+
A = A + A'
75+
if eltype <: Real
76+
vals, vecs = LAPACK2.syevd!('V', uplo, copy(A))
77+
else
78+
vals, vecs = LAPACK2.heevd!('V', uplo, copy(A))
79+
end
80+
@test diag(vecs'*A*vecs) eigvals(A)
81+
end
82+
83+
@testset "tgevc: eltype=$eltype, side=$side, howmny=$howmny" for eltype in (Float32, Float64), side in ('L', 'R', 'B'), howmny in ('A', #='B', =#'S')
84+
select = ones(Int, n)
85+
S, P = triu(randn(eltype, n, n)), triu(randn(eltype, n, n))
86+
VL, VR, m = LAPACK2.tgevc!(
87+
side,
88+
howmny,
89+
select,
90+
copy(S),
91+
copy(P),
92+
)
93+
if side ('R', 'B')
94+
w = diag(S*VR) ./ diag(P*VR)
95+
@test S*VR P*VR*Diagonal(w) rtol=sqrt(eps(eltype)) atol=sqrt(eps(eltype))
96+
end
97+
if side ('L', 'B')
98+
w = w = diag(VL'*S) ./ diag(VL'*P)
99+
@test VL'*S Diagonal(w)*VL'*P rtol=sqrt(eps(eltype)) atol=sqrt(eps(eltype))
100+
end
101+
end
71102
end

0 commit comments

Comments
 (0)