Skip to content

Commit fdf6308

Browse files
authored
Expand eigen() and add eig[vals,vecs]()
- Expand LinearAlgebra.eigen() to handle non-symmetric matrices via recent `Xgeev` - Add LinearAlgebra.eigvals() and LinearAlgebra.eigvecs()
1 parent 8b6a2a0 commit fdf6308

File tree

1 file changed

+50
-3
lines changed

1 file changed

+50
-3
lines changed

lib/cusolver/linalg.jl

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ function Base.:\(F::Union{LinearAlgebra.LAPACKFactorizations{<:Any,<:CuArray},
110110
return LinearAlgebra._cut_B(BB, 1:n)
111111
end
112112

113-
# eigenvalues
113+
# eigen
114114

115115
function LinearAlgebra.eigen(A::Symmetric{T,<:CuMatrix}) where {T<:BlasReal}
116116
A2 = copy(A.data)
@@ -126,13 +126,60 @@ end
126126

127127
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasReal}
128128
A2 = copy(A)
129-
issymmetric(A) ? Eigen(syevd!('V', 'U', A2)...) : error("GPU eigensolver supports only Hermitian or Symmetric matrices.")
129+
r = Xgeev!('N', 'V', A2)
130+
Eigen(r[1], r[3])
130131
end
131132
function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasComplex}
132133
A2 = copy(A)
133-
ishermitian(A) ? Eigen(heevd!('V', 'U', A2)...) : error("GPU eigensolver supports only Hermitian or Symmetric matrices.")
134+
r = Xgeev!('N', 'V', A2)
135+
Eigen(r[1], r[3])
134136
end
135137

138+
# eigvals
139+
140+
function LinearAlgebra.eigvals(A::Symmetric{T,<:CuMatrix}) where {T<:BlasReal}
141+
A2 = copy(A.data)
142+
syevd!('N', 'U', A2)[1]
143+
end
144+
function LinearAlgebra.eigvals(A::Hermitian{T,<:CuMatrix}) where {T<:BlasComplex}
145+
A2 = copy(A.data)
146+
heevd!('N', 'U', A2)[1]
147+
end
148+
function LinearAlgebra.eigvals(A::Hermitian{T,<:CuMatrix}) where {T<:BlasReal}
149+
eigvals(Symmetric(A))
150+
end
151+
152+
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T<:BlasReal}
153+
A2 = copy(A)
154+
Xgeev!('N', 'N', A2)[1]
155+
end
156+
function LinearAlgebra.eigvals(A::CuMatrix{T}) where {T<:BlasComplex}
157+
A2 = copy(A)
158+
Xgeev!('N', 'N', A2)[1]
159+
end
160+
161+
# eigvecs
162+
163+
function LinearAlgebra.eigvecs(A::Symmetric{T,<:CuMatrix}) where {T<:BlasReal}
164+
A2 = copy(A.data)
165+
syevd!('V', 'U', A2)[2]
166+
end
167+
function LinearAlgebra.eigvecs(A::Hermitian{T,<:CuMatrix}) where {T<:BlasComplex}
168+
A2 = copy(A.data)
169+
heevd!('V', 'U', A2)[2]
170+
end
171+
function LinearAlgebra.eigvecs(A::Hermitian{T,<:CuMatrix}) where {T<:BlasReal}
172+
eigvals(Symmetric(A))
173+
end
174+
175+
function LinearAlgebra.eigvecs(A::CuMatrix{T}) where {T<:BlasReal}
176+
A2 = copy(A)
177+
Xgeev!('N', 'V', A2)[3]
178+
end
179+
function LinearAlgebra.eigvecs(A::CuMatrix{T}) where {T<:BlasComplex}
180+
A2 = copy(A)
181+
Xgeev!('N', 'V', A2)[3]
182+
end
136183

137184
# factorizations
138185

0 commit comments

Comments
 (0)