Skip to content

Commit 5d9474a

Browse files
matteoseclimaleadt
andauthored
Expand eigen() and add eig[vals,vecs]() (#2787)
- Expand LinearAlgebra.eigen() to handle non-symmetric matrices via recent `Xgeev` - Add LinearAlgebra.eigvals() and LinearAlgebra.eigvecs() Co-authored-by: Tim Besard <[email protected]>
1 parent bb0c275 commit 5d9474a

File tree

2 files changed

+165
-6
lines changed

2 files changed

+165
-6
lines changed

lib/cusolver/linalg.jl

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,27 +117,105 @@ Base.copyto!(dst::Symmetric{<:Any,<:CuMatrix}, src::Symmetric{<:Any,<:CuMatrix})
117117
Base.copyto!(dst::Hermitian{<:Any,<:CuMatrix}, src::Hermitian{<:Any,<:CuMatrix}) =
118118
@invoke copyto!(dst::Hermitian, src::Hermitian)
119119

120-
# eigenvalues
120+
# eigen
121121

122122
function LinearAlgebra.eigen(A::Symmetric{T,<:CuMatrix}) where {T<:BlasReal}
123123
A2 = copy(A.data)
124-
Eigen(syevd!('V', 'U', A2)...)
124+
return Eigen(syevd!('V', 'U', A2)...)
125125
end
126126
function LinearAlgebra.eigen(A::Hermitian{T,<:CuMatrix}) where {T<:BlasComplex}
127127
A2 = copy(A.data)
128-
Eigen(heevd!('V', 'U', A2)...)
128+
return Eigen(heevd!('V', 'U', A2)...)
129129
end
130130
function LinearAlgebra.eigen(A::Hermitian{T,<:CuMatrix}) where {T<:BlasReal}
131-
eigen(Symmetric(A))
131+
return eigen(Symmetric(A))
132132
end
133133

134134
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasReal}
135135
A2 = copy(A)
136-
issymmetric(A) ? Eigen(syevd!('V', 'U', A2)...) : error("GPU eigensolver supports only Hermitian or Symmetric matrices.")
136+
if issymmetric(A)
137+
return Eigen(syevd!('V', 'U', A2)...)
138+
else
139+
W, _, VR = Xgeev!('N', 'V', A2)
140+
C = Complex{T}
141+
U = CuMatrix{C}([1.0 1.0; im -im])
142+
VR = CuMatrix{C}(VR)
143+
h_W = collect(W)
144+
n = length(W)
145+
j = 1
146+
while j <= n
147+
if imag(h_W[j]) == 0
148+
j += 1
149+
else
150+
VR[:, j:(j + 1)] .= VR[:, j:(j + 1)] * U
151+
j += 2
152+
end
153+
end
154+
return Eigen(W, VR)
155+
end
137156
end
138157
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasComplex}
139158
A2 = copy(A)
140-
ishermitian(A) ? Eigen(heevd!('V', 'U', A2)...) : error("GPU eigensolver supports only Hermitian or Symmetric matrices.")
159+
if ishermitian(A)
160+
return Eigen(heevd!('V', 'U', A2)...)
161+
else
162+
r = Xgeev!('N', 'V', A2)
163+
return Eigen(r[1], r[3])
164+
end
165+
end
166+
167+
# eigvals
168+
169+
function LinearAlgebra.eigvals(A::Symmetric{T, <:CuMatrix}) where {T <: BlasReal}
170+
A2 = copy(A.data)
171+
return syevd!('N', 'U', A2)
172+
end
173+
function LinearAlgebra.eigvals(A::Hermitian{T, <:CuMatrix}) where {T <: BlasComplex}
174+
A2 = copy(A.data)
175+
return heevd!('N', 'U', A2)
176+
end
177+
function LinearAlgebra.eigvals(A::Hermitian{T, <:CuMatrix}) where {T <: BlasReal}
178+
return eigvals(Symmetric(A))
179+
end
180+
181+
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T <: BlasReal}
182+
A2 = copy(A)
183+
if issymmetric(A)
184+
return syevd!('N', 'U', A2)
185+
else
186+
return Xgeev!('N', 'N', A2)[1]
187+
end
188+
end
189+
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T <: BlasComplex}
190+
A2 = copy(A)
191+
if ishermitian(A)
192+
return heevd!('N', 'U', A2)
193+
else
194+
return Xgeev!('N', 'N', A2)[1]
195+
end
196+
end
197+
198+
# eigvecs
199+
200+
function LinearAlgebra.eigvecs(A::Symmetric{T, <:CuMatrix}) where {T <: BlasReal}
201+
E = eigen(A)
202+
return E.vectors
203+
end
204+
function LinearAlgebra.eigvecs(A::Hermitian{T, <:CuMatrix}) where {T <: BlasComplex}
205+
E = eigen(A)
206+
return E.vectors
207+
end
208+
function LinearAlgebra.eigvecs(A::Hermitian{T, <:CuMatrix}) where {T <: BlasReal}
209+
return eigvecs(Symmetric(A))
210+
end
211+
212+
function LinearAlgebra.eigvecs(A::CuMatrix{T}) where {T <: BlasReal}
213+
E = eigen(A)
214+
return E.vectors
215+
end
216+
function LinearAlgebra.eigvecs(A::CuMatrix{T}) where {T <: BlasComplex}
217+
E = eigen(A)
218+
return E.vectors
141219
end
142220

143221
# matrix functions

test/libraries/cusolver/dense.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@ p = 5
99
l = 13
1010
k = 1
1111

12+
# Adapted from LinearAlgebra.sorteig!().
13+
# Warning: not very efficient, but works.
14+
eigsortby::Real) = λ
15+
eigsortby::Complex) = (real(λ), imag(λ))
16+
function sorteig!::AbstractVector, X::AbstractMatrix, sortby::Union{Function, Nothing} = eigsortby)
17+
if sortby !== nothing # && !issorted(λ, by=sortby)
18+
p = sortperm(λ; by = sortby)
19+
λ .= λ[p] # permute!(λ, p)
20+
X .= X[:, p] # Base.permutecols!!(X, p)
21+
end
22+
return λ, X
23+
end
24+
sorteig!::AbstractVector, sortby::Union{Function, Nothing} = eigsortby) = sortby === nothing ? λ : sort!(λ, by = sortby)
25+
1226
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
1327
@testset "gesv!" begin
1428
@testset "irs_precision = AUTO" begin
@@ -315,6 +329,39 @@ k = 1
315329
end
316330
end
317331

332+
# Note: Xgeev was introduced in CUDA 12.6.2 / CUSOLVER 11.7.1
333+
if CUSOLVER.version() >= v"11.7.1"
334+
@testset "geev!" begin
335+
local d_W, d_V
336+
337+
A = rand(elty,m,m)
338+
d_A = CuArray(A)
339+
Eig = eigen(A)
340+
d_eig = eigen(d_A)
341+
sorteig!(d_eig.values, d_eig.vectors)
342+
@test Eig.values collect(d_eig.values)
343+
h_V = collect(d_eig.vectors)
344+
h_V⁻¹ = inv(h_V)
345+
@test abs.(h_V⁻¹*Eig.vectors) I
346+
347+
A = rand(elty,m,m)
348+
d_A = CuArray(A)
349+
W = eigvals(A)
350+
d_W = eigvals(d_A)
351+
sorteig!(d_W)
352+
@test W collect(d_W)
353+
354+
A = rand(elty,m,m)
355+
d_A = CuArray(A)
356+
V = eigvecs(A)
357+
d_W = eigvals(d_A)
358+
d_V = eigvecs(d_A)
359+
sorteig!(d_W, d_V)
360+
V⁻¹ = inv(V)
361+
@test abs.(V⁻¹*collect(d_V)) I
362+
end
363+
end
364+
318365
@testset "syevd!" begin
319366
A = rand(elty,m,m)
320367
A += A'
@@ -356,6 +403,7 @@ k = 1
356403
d_A = CuArray(A)
357404
Eig = eigen(LinearAlgebra.Hermitian(A))
358405
d_eig = eigen(d_A)
406+
sorteig!(d_eig.values, d_eig.vectors)
359407
@test Eig.values collect(d_eig.values)
360408
d_eig = eigen(LinearAlgebra.Hermitian(d_A))
361409
@test Eig.values collect(d_eig.values)
@@ -369,6 +417,39 @@ k = 1
369417
@test abs.(Eig.vectors'*h_V) I
370418
end
371419

420+
A = rand(elty,m,m)
421+
A += A'
422+
d_A = CuArray(A)
423+
W = eigvals(LinearAlgebra.Hermitian(A))
424+
d_W = eigvals(d_A)
425+
sorteig!(d_W)
426+
@test W collect(d_W)
427+
d_W = eigvals(LinearAlgebra.Hermitian(d_A))
428+
@test W collect(d_W)
429+
if elty <: Real
430+
W = eigvals(LinearAlgebra.Symmetric(A))
431+
d_W = eigvals(LinearAlgebra.Symmetric(d_A))
432+
@test W collect(d_W)
433+
end
434+
435+
A = rand(elty,m,m)
436+
A += A'
437+
d_A = CuArray(A)
438+
V = eigvecs(LinearAlgebra.Hermitian(A))
439+
d_W = eigvals(d_A)
440+
d_V = eigvecs(d_A)
441+
sorteig!(d_W, d_V)
442+
h_V = collect(d_V)
443+
@test abs.(V'*h_V) I
444+
d_V = eigvecs(LinearAlgebra.Hermitian(d_A))
445+
h_V = collect(d_V)
446+
@test abs.(V'*h_V) I
447+
if elty <: Real
448+
V = eigvecs(LinearAlgebra.Symmetric(A))
449+
d_V = eigvecs(LinearAlgebra.Symmetric(d_A))
450+
h_V = collect(d_V)
451+
@test abs.(V'*h_V) I
452+
end
372453
end
373454

374455
@testset "sygvd!" begin

0 commit comments

Comments
 (0)