1- # TODO : do not export but mark as public ?
1+ # TODO : export? or not export but mark as public ?
22function eigh!(A:: AbstractMatrix , args... ; kwargs... )
33 return eigh_full!(A, args... ; kwargs... )
44end
55
6- function eigh_full!(A:: AbstractMatrix ,
7- D:: AbstractVector = similar(A, real(eltype(A)), size(A, 1 )),
8- V:: AbstractMatrix = similar(A, size(A));
9- kwargs... )
10- return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs... ); kwargs... )
6+ function eigh_full!(A:: AbstractMatrix , DV= eigh_full_init(A); kwargs... )
7+ return eigh_full!(A, DV, default_algorithm(eigh_full!, A; kwargs... ))
118end
12- function eigh_vals!(A:: AbstractMatrix ,
13- D:: AbstractVector = similar(A, real(eltype(A)), size(A, 1 ));
14- kwargs... )
15- return eigh_vals!(A, D, default_backend(eigh_vals!, A; kwargs... ); kwargs... )
9+ function eigh_vals!(A:: AbstractMatrix , D= eigh_vals_init(A); kwargs... )
10+ return eigh_vals!(A, D, default_algorithm(eigh_vals!, A; kwargs... ))
1611end
17- function eigh_trunc!(A:: AbstractMatrix ;
18- kwargs... )
19- return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs... ); kwargs... )
12+ function eigh_trunc!(A:: AbstractMatrix ; kwargs... )
13+ return eigh_trunc!(A, default_algorithm(eigh_trunc!, A; kwargs... ))
2014end
2115
22- function default_backend(:: typeof (eigh_full!), A:: AbstractMatrix ; kwargs... )
23- return default_eigh_backend(A; kwargs... )
16+ function eigh_full_init(A:: AbstractMatrix )
17+ n = size(A, 1 ) # square check will happen later
18+ D = similar(A, real(eltype(A)), n)
19+ V = similar(A, (n, n))
20+ return (D, V)
2421end
25- function default_backend(:: typeof (eigh_vals!), A:: AbstractMatrix ; kwargs... )
26- return default_eigh_backend(A; kwargs... )
22+ function eigh_vals_init(A:: AbstractMatrix )
23+ n = size(A, 1 ) # square check will happen later
24+ D = similar(A, real(eltype(A)), n)
25+ return D
26+ end
27+
28+ function default_algorithm(:: typeof (eigh_full!), A:: AbstractMatrix ; kwargs... )
29+ return default_eigh_algorithm(A; kwargs... )
2730end
28- function default_backend(:: typeof (eigh_trunc!), A:: AbstractMatrix ; kwargs... )
29- return default_eigh_backend(A; kwargs... )
31+ function default_algorithm(:: typeof (eigh_vals!), A:: AbstractMatrix ; kwargs... )
32+ return default_eigh_algorithm(A; kwargs... )
33+ end
34+ function default_algorithm(:: typeof (eigh_trunc!), A:: AbstractMatrix ; kwargs... )
35+ return default_eigh_algorithm(A; kwargs... )
3036end
3137
32- function default_eigh_backend (A:: StridedMatrix{T} ; kwargs... ) where {T<: BlasFloat }
33- return LAPACKBackend( )
38+ function default_eigh_algorithm (A:: StridedMatrix{T} ; kwargs... ) where {T<: BlasFloat }
39+ return LAPACK_RobustRepresentations(; kwargs ... )
3440end
3541
36- function check_eigh_full_input(A, D, V)
42+ function check_eigh_full_input(A:: AbstractMatrix , ( D, V) )
3743 m, n = size(A)
3844 m == n || throw(ArgumentError(" Eigenvalue decompsition requires square matrix" ))
3945 size(D) == (n,) ||
@@ -42,82 +48,66 @@ function check_eigh_full_input(A, D, V)
4248 throw(DimensionMismatch(" Eigenvector matrix `V` must have size equal to A" ))
4349 return nothing
4450end
45- function check_eigh_vals_input(A, D )
51+ function check_eigh_vals_input(A:: AbstractMatrix , (D, V) )
4652 m, n = size(A)
4753 m == n || throw(ArgumentError(" Eigenvalue decompsition requires square matrix" ))
4854 size(D) == (n,) ||
4955 throw(DimensionMismatch(" Eigenvalue vector `D` must have length equal to size(A, 1)" ))
5056 return nothing
5157end
5258
53- @static if VERSION >= v" 1.12-DEV.0"
54- const RobustRepresentations = LinearAlgebra. RobustRepresentations
55- else
56- struct RobustRepresentations end
57- end
58-
59- function eigh_full!(A:: AbstractMatrix ,
60- D:: AbstractVector ,
61- V:: AbstractMatrix ,
62- backend:: LAPACKBackend ;
63- alg= RobustRepresentations(),
64- kwargs... )
65- check_eigh_full_input(A, D, V)
66- if alg == RobustRepresentations()
67- YALAPACK. heevr!(A, D, V; kwargs... )
68- elseif alg == LinearAlgebra. DivideAndConquer()
69- YALAPACK. heevd!(A, D, V; kwargs... )
70- elseif alg == LinearAlgebra. QRIteration()
71- YALAPACK. heev!(A, D, V; kwargs... )
59+ const LAPACK_EighAlgorithm = Union{LAPACK_RobustRepresentations,LAPACK_QRIteration,
60+ LAPACK_DivideAndConquer}
61+ function eigh_full!(A:: AbstractMatrix , DV, alg:: LAPACK_EighAlgorithm )
62+ check_eigh_full_input(A, DV)
63+ D, V = DV
64+ if alg isa LAPACK_RobustRepresentations
65+ YALAPACK. heevr!(A, D, V; alg. kwargs... )
66+ elseif alg isa LAPACK_DivideAndConquer
67+ YALAPACK. heevd!(A, D, V; alg. kwargs... )
7268 else
73- throw(ArgumentError( " Unknown LAPACK eigenvalue algorithm $ alg" ) )
69+ YALAPACK . heev!(A, D, V; alg. kwargs ... )
7470 end
7571 return D, V
7672end
7773
78- function eigh_vals!(A:: AbstractMatrix ,
79- D:: AbstractVector ,
80- backend:: LAPACKBackend ;
81- alg= RobustRepresentations(),
82- kwargs... )
74+ function eigh_vals!(A:: AbstractMatrix , D, alg:: LAPACK_EighAlgorithm )
8375 check_eigh_vals_input(A, D)
8476 V = similar(A, (size(A, 1 ), 0 ))
85- if alg == RobustRepresentations()
86- YALAPACK. heevr!(A, D, V; kwargs... )
87- elseif alg == LinearAlgebra. DivideAndConquer()
88- YALAPACK. heevd!(A, D, V; kwargs... )
89- elseif alg == LinearAlgebra. QRIteration()
90- YALAPACK. heev!(A, D, V; kwargs... )
77+ if alg isa LAPACK_RobustRepresentations
78+ YALAPACK. heevr!(A, D, V; alg. kwargs... )
79+ elseif alg isa LAPACK_DivideAndConquer
80+ YALAPACK. heevd!(A, D, V; alg. kwargs... )
9181 else
92- throw(ArgumentError( " Unknown LAPACK eigenvalue algorithm $ alg" ) )
82+ YALAPACK . heev!(A, D, V; alg. kwargs ... )
9383 end
94- return D
84+ return D, V
9585end
9686
9787# for eigh_trunc!, it doesn't make sense to preallocate D and V as we don't know their sizes
98- function eigh_trunc!(A:: AbstractMatrix ,
99- backend:: LAPACKBackend ;
100- alg= RobustRepresentations(),
101- atol= zero(real(eltype(A))),
102- rtol= zero(real(eltype(A))),
103- rank= size(A, 1 ),
104- kwargs... )
105- if alg == RobustRepresentations()
106- D, V = YALAPACK. heevr!(A; kwargs... )
107- elseif alg == LinearAlgebra. DivideAndConquer()
108- D, V = YALAPACK. heevd!(A; kwargs... )
109- elseif alg == LinearAlgebra. QRIteration()
110- D, V = YALAPACK. heev!(A; kwargs... )
111- else
112- throw(ArgumentError(" Unknown LAPACK eigenvalue algorithm $alg " ))
113- end
114- # eigenvalues are sorted in ascending order
115- # TODO : do we assume that they are positive, or should we check for this?
116- # or do we want to truncate based on absolute value and thus sort differently?
117- n = length(D)
118- tol = convert(eltype(D), max(atol, rtol * D[n]))
119- s = max(n - rank + 1 , findfirst(>= (tol), D))
120- # TODO : do we want views here, such that we do not need extra allocations if we later
121- # copy them into other storage
122- return D[n: - 1 : s], V[:, n: - 1 : s]
123- end
88+ # function eigh_trunc!(A::AbstractMatrix,
89+ # backend::LAPACKBackend;
90+ # alg=RobustRepresentations(),
91+ # atol=zero(real(eltype(A))),
92+ # rtol=zero(real(eltype(A))),
93+ # rank=size(A, 1),
94+ # kwargs...)
95+ # if alg == RobustRepresentations()
96+ # D, V = YALAPACK.heevr!(A; kwargs...)
97+ # elseif alg == LinearAlgebra.DivideAndConquer()
98+ # D, V = YALAPACK.heevd!(A; kwargs...)
99+ # elseif alg == LinearAlgebra.QRIteration()
100+ # D, V = YALAPACK.heev!(A; kwargs...)
101+ # else
102+ # throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
103+ # end
104+ # # eigenvalues are sorted in ascending order
105+ # # TODO : do we assume that they are positive, or should we check for this?
106+ # # or do we want to truncate based on absolute value and thus sort differently?
107+ # n = length(D)
108+ # tol = convert(eltype(D), max(atol, rtol * D[n]))
109+ # s = max(n - rank + 1, findfirst(>=(tol), D))
110+ # # TODO : do we want views here, such that we do not need extra allocations if we later
111+ # # copy them into other storage
112+ # return D[n:-1:s], V[:, n:-1:s]
113+ # end
0 commit comments