Skip to content

Commit 5559c1e

Browse files
committed
Fix most tests
1 parent 18c43b6 commit 5559c1e

File tree

2 files changed

+54
-35
lines changed

2 files changed

+54
-35
lines changed

lib/cusolver/linalg.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ function Base.:\(F::Union{LinearAlgebra.LAPACKFactorizations{<:Any,<:CuArray},
110110
return LinearAlgebra._cut_B(BB, 1:n)
111111
end
112112

113+
# Adapted from LinearAlgebra.sorteig!().
114+
# Warning: not very efficient, but works.
115+
eigsortby::Real) = λ
116+
eigsortby::Complex) = (real(λ),imag(λ))
117+
function sorteig!::AbstractVector, X::AbstractMatrix, sortby::Union{Function,Nothing}=eigsortby)
118+
if sortby !== nothing # && !issorted(λ, by=sortby)
119+
p = sortperm(λ; by=sortby)
120+
λ .= λ[p] # permute!(λ, p)
121+
X .= X[:, p] # Base.permutecols!!(X, p)
122+
end
123+
return λ, X
124+
end
125+
sorteig!::AbstractVector, sortby::Union{Function,Nothing}=eigsortby) = sortby === nothing ? λ : sort!(λ, by=sortby)
126+
113127
# eigen
114128

115129
function LinearAlgebra.eigen(A::Symmetric{T,<:CuMatrix}) where {T<:BlasReal}
@@ -127,58 +141,58 @@ end
127141
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasReal}
128142
A2 = copy(A)
129143
r = Xgeev!('N', 'V', A2)
130-
return Eigen(r[1], r[3])
144+
return Eigen(sorteig!(r[1], r[3])...)
131145
end
132146
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasComplex}
133147
A2 = copy(A)
134148
r = Xgeev!('N', 'V', A2)
135-
return Eigen(r[1], r[3])
149+
return Eigen(sorteig!(r[1], r[3])...)
136150
end
137151

138152
# eigvals
139153

140154
function LinearAlgebra.eigvals(A::Symmetric{T, <:CuMatrix}) where {T <: BlasReal}
141155
A2 = copy(A.data)
142-
return syevd!('N', 'U', A2)[1]
156+
return syevd!('N', 'U', A2)
143157
end
144158
function LinearAlgebra.eigvals(A::Hermitian{T, <:CuMatrix}) where {T <: BlasComplex}
145159
A2 = copy(A.data)
146-
return heevd!('N', 'U', A2)[1]
160+
return heevd!('N', 'U', A2)
147161
end
148162
function LinearAlgebra.eigvals(A::Hermitian{T, <:CuMatrix}) where {T <: BlasReal}
149163
return eigvals(Symmetric(A))
150164
end
151165

152166
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T <: BlasReal}
153167
A2 = copy(A)
154-
return Xgeev!('N', 'N', A2)[1]
168+
return sorteig!(Xgeev!('N', 'N', A2)[1])
155169
end
156170
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T <: BlasComplex}
157171
A2 = copy(A)
158-
return Xgeev!('N', 'N', A2)[1]
172+
return sorteig!(Xgeev!('N', 'N', A2)[1])
159173
end
160174

161175
# eigvecs
162176

163177
function LinearAlgebra.eigvecs(A::Symmetric{T, <:CuMatrix}) where {T <: BlasReal}
164-
A2 = copy(A.data)
165-
return syevd!('V', 'U', A2)[2]
178+
E = eigen(A)
179+
return E.vectors
166180
end
167181
function LinearAlgebra.eigvecs(A::Hermitian{T, <:CuMatrix}) where {T <: BlasComplex}
168-
A2 = copy(A.data)
169-
return heevd!('V', 'U', A2)[2]
182+
E = eigen(A)
183+
return E.vectors
170184
end
171185
function LinearAlgebra.eigvecs(A::Hermitian{T, <:CuMatrix}) where {T <: BlasReal}
172186
return eigvecs(Symmetric(A))
173187
end
174188

175189
function LinearAlgebra.eigvecs(A::CuMatrix{T}) where {T <: BlasReal}
176-
A2 = copy(A)
177-
return Xgeev!('N', 'V', A2)[3]
190+
E = eigen(A)
191+
return E.vectors
178192
end
179193
function LinearAlgebra.eigvecs(A::CuMatrix{T}) where {T <: BlasComplex}
180-
A2 = copy(A)
181-
return Xgeev!('N', 'V', A2)[3]
194+
E = eigen(A)
195+
return E.vectors
182196
end
183197

184198
# factorizations

test/libraries/cusolver/dense.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -316,23 +316,27 @@ k = 1
316316
end
317317

318318
@testset "geev!" begin
319-
A = rand(elty,m,m)
320-
d_A = CuArray(A)
319+
## Note: we have Xgeev in dense_generic.jl, but no geev in dense.jl.
320+
# A = rand(elty,m,m)
321+
# d_A = CuArray(A)
321322
local d_W, d_V
322-
d_W, d_V = CUSOLVER.Xgeev!('N','V', d_A)
323-
d_W_b, d_V_b = LAPACK.geev!('N','V', CuArray(A))
324-
@test d_W d_W_b
325-
@test d_V d_V_b
326-
h_W = collect(d_W)
327-
h_V = collect(d_V)
328-
h_V⁻¹ = inv(h_V)
329-
Eig = eigen(A)
330-
@test Eig.values h_W
331-
@test abs.(Eig.vectors*h_V⁻¹) I
332-
d_A = CuArray(A)
333-
d_W = CUSOLVER.Xgeev!('N','N', d_A)
334-
h_W = collect(d_W)
335-
@test Eig.values h_W
323+
# d_W, _, d_V = CUSOLVER.Xgeev!('N','V', d_A)
324+
# # d_W_b, _, d_V_b = LAPACK.geev!('N','V', CuArray(A))
325+
# # @test d_W ≈ d_W_b
326+
# # @test d_V ≈ d_V_b
327+
# W_b, _, V_b = LAPACK.geev!('N','V', A)
328+
# @test collect(d_W) ≈ W_b
329+
# @test collect(d_V) ≈ V_b
330+
# h_W = collect(d_W)
331+
# h_V = collect(d_V)
332+
# h_V⁻¹ = inv(h_V)
333+
# Eig = eigen(A)
334+
# @test Eig.values ≈ h_W
335+
# @test abs.(Eig.vectors*h_V⁻¹) ≈ I
336+
# d_A = CuArray(A)
337+
# d_W = CUSOLVER.Xgeev!('N','N', d_A)
338+
# h_W = collect(d_W)
339+
# @test Eig.values ≈ h_W
336340

337341
A = rand(elty,m,m)
338342
d_A = CuArray(A)
@@ -429,16 +433,17 @@ k = 1
429433
A += A'
430434
d_A = CuArray(A)
431435
V = eigvecs(LinearAlgebra.Hermitian(A))
432-
V⁻¹ = inv(V)
433436
d_V = eigvecs(d_A)
434-
@test abs.(collect(d_V)*V⁻¹) I
437+
h_V = collect(d_V)
438+
@test abs.(V'*h_V) I
435439
d_V = eigvecs(LinearAlgebra.Hermitian(d_A))
436-
@test abs.(collect(d_V)*V⁻¹) I
440+
h_V = collect(d_V)
441+
@test abs.(V'*h_V) I
437442
if elty <: Real
438443
V = eigvecs(LinearAlgebra.Symmetric(A))
439-
V⁻¹ = inv(V)
440444
d_V = eigvecs(LinearAlgebra.Symmetric(d_A))
441-
@test abs.(collect(d_V)*V⁻¹) I
445+
h_V = collect(d_V)
446+
@test abs.(V'*h_V) I
442447
end
443448
end
444449

0 commit comments

Comments
 (0)