Skip to content

Commit a323813

Browse files
committed
some progress
1 parent 6a9b74b commit a323813

File tree

9 files changed

+343
-355
lines changed

9 files changed

+343
-355
lines changed

src/MatrixAlgebraKit.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
module MatrixAlgebraKit
22

33
using LinearAlgebra: LinearAlgebra
4+
using LinearAlgebra: Diagonal
45
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!
56

67
export qr_compact!, qr_full!
7-
export eigh_full!, eigh_vals!, eigh_trunc!
88
export svd_compact!, svd_full!, svd_vals!, svd_trunc!
9+
# export eigh_full!, eigh_vals!, eigh_trunc!
10+
export truncrank, trunctol, TruncationKeepSorted, TruncationKeepFiltered
911

1012
include("auxiliary.jl")
11-
include("backend.jl")
13+
include("algorithms.jl")
14+
include("truncation.jl")
1215
include("yalapack.jl")
1316
include("qr.jl")
1417
include("svd.jl")
15-
include("eigh.jl")
18+
# include("eigh.jl")
1619

1720
end

src/algorithms.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
struct Algorithm{name,K}
2+
kwargs::K
3+
end
4+
5+
macro algdef(name)
6+
esc(quote
7+
const $name{K} = Algorithm{$(QuoteNode(name)),K}
8+
export $name
9+
function $name(; kwargs...)
10+
return $name{typeof(kwargs)}(kwargs)
11+
end
12+
function Base.print(io::IO, alg::$name)
13+
print(io, $name, "(")
14+
next = iterate(alg.kwargs)
15+
isnothing(next) && return print(io, ")")
16+
(k, v), state = next
17+
print(io, "; ", string(k), "=", string(v))
18+
next = iterate(alg.kwargs, state)
19+
while !isnothing(next)
20+
(k, v), state = next
21+
print(io, ", ", string(k), "=", string(v))
22+
next = iterate(alg.kwargs, state)
23+
end
24+
return print(io, ")")
25+
end
26+
end)
27+
end
28+
29+
@algdef LAPACK_QRIteration
30+
@algdef LAPACK_DivideAndConquer
31+
@algdef LAPACK_RobustRepresentations
32+
@algdef LAPACK_HouseholderQR

src/eigh.jl

Lines changed: 72 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,45 @@
1-
# TODO: do not export but mark as public ?
1+
# TODO: export? or not export but mark as public ?
22
function eigh!(A::AbstractMatrix, args...; kwargs...)
33
return eigh_full!(A, args...; kwargs...)
44
end
55

6-
function eigh_full!(A::AbstractMatrix,
7-
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)),
8-
V::AbstractMatrix=similar(A, size(A));
9-
kwargs...)
10-
return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs...); kwargs...)
6+
function eigh_full!(A::AbstractMatrix, DV=eigh_full_init(A); kwargs...)
7+
return eigh_full!(A, DV, default_algorithm(eigh_full!, A; kwargs...))
118
end
12-
function eigh_vals!(A::AbstractMatrix,
13-
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1));
14-
kwargs...)
15-
return eigh_vals!(A, D, default_backend(eigh_vals!, A; kwargs...); kwargs...)
9+
function eigh_vals!(A::AbstractMatrix, D=eigh_vals_init(A); kwargs...)
10+
return eigh_vals!(A, D, default_algorithm(eigh_vals!, A; kwargs...))
1611
end
17-
function eigh_trunc!(A::AbstractMatrix;
18-
kwargs...)
19-
return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs...); kwargs...)
12+
function eigh_trunc!(A::AbstractMatrix; kwargs...)
13+
return eigh_trunc!(A, default_algorithm(eigh_trunc!, A; kwargs...))
2014
end
2115

22-
function default_backend(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
23-
return default_eigh_backend(A; kwargs...)
16+
function eigh_full_init(A::AbstractMatrix)
17+
n = size(A, 1) # square check will happen later
18+
D = similar(A, real(eltype(A)), n)
19+
V = similar(A, (n, n))
20+
return (D, V)
2421
end
25-
function default_backend(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...)
26-
return default_eigh_backend(A; kwargs...)
22+
function eigh_vals_init(A::AbstractMatrix)
23+
n = size(A, 1) # square check will happen later
24+
D = similar(A, real(eltype(A)), n)
25+
return D
26+
end
27+
28+
function default_algorithm(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
29+
return default_eigh_algorithm(A; kwargs...)
2730
end
28-
function default_backend(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
29-
return default_eigh_backend(A; kwargs...)
31+
function default_algorithm(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...)
32+
return default_eigh_algorithm(A; kwargs...)
33+
end
34+
function default_algorithm(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
35+
return default_eigh_algorithm(A; kwargs...)
3036
end
3137

32-
function default_eigh_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
33-
return LAPACKBackend()
38+
function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
39+
return LAPACK_RobustRepresentations(; kwargs...)
3440
end
3541

36-
function check_eigh_full_input(A, D, V)
42+
function check_eigh_full_input(A::AbstractMatrix, (D, V))
3743
m, n = size(A)
3844
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
3945
size(D) == (n,) ||
@@ -42,82 +48,66 @@ function check_eigh_full_input(A, D, V)
4248
throw(DimensionMismatch("Eigenvector matrix `V` must have size equal to A"))
4349
return nothing
4450
end
45-
function check_eigh_vals_input(A, D)
51+
function check_eigh_vals_input(A::AbstractMatrix, (D, V))
4652
m, n = size(A)
4753
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
4854
size(D) == (n,) ||
4955
throw(DimensionMismatch("Eigenvalue vector `D` must have length equal to size(A, 1)"))
5056
return nothing
5157
end
5258

53-
@static if VERSION >= v"1.12-DEV.0"
54-
const RobustRepresentations = LinearAlgebra.RobustRepresentations
55-
else
56-
struct RobustRepresentations end
57-
end
58-
59-
function eigh_full!(A::AbstractMatrix,
60-
D::AbstractVector,
61-
V::AbstractMatrix,
62-
backend::LAPACKBackend;
63-
alg=RobustRepresentations(),
64-
kwargs...)
65-
check_eigh_full_input(A, D, V)
66-
if alg == RobustRepresentations()
67-
YALAPACK.heevr!(A, D, V; kwargs...)
68-
elseif alg == LinearAlgebra.DivideAndConquer()
69-
YALAPACK.heevd!(A, D, V; kwargs...)
70-
elseif alg == LinearAlgebra.QRIteration()
71-
YALAPACK.heev!(A, D, V; kwargs...)
59+
const LAPACK_EighAlgorithm = Union{LAPACK_RobustRepresentations,LAPACK_QRIteration,
60+
LAPACK_DivideAndConquer}
61+
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
62+
check_eigh_full_input(A, DV)
63+
D, V = DV
64+
if alg isa LAPACK_RobustRepresentations
65+
YALAPACK.heevr!(A, D, V; alg.kwargs...)
66+
elseif alg isa LAPACK_DivideAndConquer
67+
YALAPACK.heevd!(A, D, V; alg.kwargs...)
7268
else
73-
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
69+
YALAPACK.heev!(A, D, V; alg.kwargs...)
7470
end
7571
return D, V
7672
end
7773

78-
function eigh_vals!(A::AbstractMatrix,
79-
D::AbstractVector,
80-
backend::LAPACKBackend;
81-
alg=RobustRepresentations(),
82-
kwargs...)
74+
function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
8375
check_eigh_vals_input(A, D)
8476
V = similar(A, (size(A, 1), 0))
85-
if alg == RobustRepresentations()
86-
YALAPACK.heevr!(A, D, V; kwargs...)
87-
elseif alg == LinearAlgebra.DivideAndConquer()
88-
YALAPACK.heevd!(A, D, V; kwargs...)
89-
elseif alg == LinearAlgebra.QRIteration()
90-
YALAPACK.heev!(A, D, V; kwargs...)
77+
if alg isa LAPACK_RobustRepresentations
78+
YALAPACK.heevr!(A, D, V; alg.kwargs...)
79+
elseif alg isa LAPACK_DivideAndConquer
80+
YALAPACK.heevd!(A, D, V; alg.kwargs...)
9181
else
92-
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
82+
YALAPACK.heev!(A, D, V; alg.kwargs...)
9383
end
94-
return D
84+
return D, V
9585
end
9686

9787
# for eigh_trunc!, it doesn't make sense to preallocate D and V as we don't know their sizes
98-
function eigh_trunc!(A::AbstractMatrix,
99-
backend::LAPACKBackend;
100-
alg=RobustRepresentations(),
101-
atol=zero(real(eltype(A))),
102-
rtol=zero(real(eltype(A))),
103-
rank=size(A, 1),
104-
kwargs...)
105-
if alg == RobustRepresentations()
106-
D, V = YALAPACK.heevr!(A; kwargs...)
107-
elseif alg == LinearAlgebra.DivideAndConquer()
108-
D, V = YALAPACK.heevd!(A; kwargs...)
109-
elseif alg == LinearAlgebra.QRIteration()
110-
D, V = YALAPACK.heev!(A; kwargs...)
111-
else
112-
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
113-
end
114-
# eigenvalues are sorted in ascending order
115-
# TODO: do we assume that they are positive, or should we check for this?
116-
# or do we want to truncate based on absolute value and thus sort differently?
117-
n = length(D)
118-
tol = convert(eltype(D), max(atol, rtol * D[n]))
119-
s = max(n - rank + 1, findfirst(>=(tol), D))
120-
# TODO: do we want views here, such that we do not need extra allocations if we later
121-
# copy them into other storage
122-
return D[n:-1:s], V[:, n:-1:s]
123-
end
88+
# function eigh_trunc!(A::AbstractMatrix,
89+
# backend::LAPACKBackend;
90+
# alg=RobustRepresentations(),
91+
# atol=zero(real(eltype(A))),
92+
# rtol=zero(real(eltype(A))),
93+
# rank=size(A, 1),
94+
# kwargs...)
95+
# if alg == RobustRepresentations()
96+
# D, V = YALAPACK.heevr!(A; kwargs...)
97+
# elseif alg == LinearAlgebra.DivideAndConquer()
98+
# D, V = YALAPACK.heevd!(A; kwargs...)
99+
# elseif alg == LinearAlgebra.QRIteration()
100+
# D, V = YALAPACK.heev!(A; kwargs...)
101+
# else
102+
# throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
103+
# end
104+
# # eigenvalues are sorted in ascending order
105+
# # TODO: do we assume that they are positive, or should we check for this?
106+
# # or do we want to truncate based on absolute value and thus sort differently?
107+
# n = length(D)
108+
# tol = convert(eltype(D), max(atol, rtol * D[n]))
109+
# s = max(n - rank + 1, findfirst(>=(tol), D))
110+
# # TODO: do we want views here, such that we do not need extra allocations if we later
111+
# # copy them into other storage
112+
# return D[n:-1:s], V[:, n:-1:s]
113+
# end

src/qr.jl

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,72 @@
1-
function qr_full!(A::AbstractMatrix,
2-
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
3-
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
4-
kwargs...)
5-
return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...); kwargs...)
1+
function qr_full!(A::AbstractMatrix, QR=qr_full_init(A); kwargs...)
2+
return qr_full!(A, QR, default_algorithm(qr_full!, A; kwargs...))
63
end
7-
function qr_compact!(A::AbstractMatrix,
8-
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
9-
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
10-
kwargs...)
11-
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...); kwargs...)
4+
function qr_compact!(A::AbstractMatrix, QR=qr_compact_init(A); kwargs...)
5+
return qr_compact!(A, QR, default_algorithm(qr_compact!, A; kwargs...))
126
end
137

14-
function default_backend(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
15-
return default_qr_backend(A; kwargs...)
8+
function qr_full_init(A::AbstractMatrix)
9+
m, n = size(A)
10+
Q = similar(A, (m, m))
11+
R = similar(A, (m, n))
12+
return (Q, R)
13+
end
14+
function qr_compact_init(A::AbstractMatrix)
15+
m, n = size(A)
16+
minmn = min(m, n)
17+
Q = similar(A, (m, minmn))
18+
R = similar(A, (minmn, n))
19+
return (Q, R)
20+
end
21+
22+
function default_algorithm(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
23+
return default_qr_algorithm(A; kwargs...)
1624
end
17-
function default_backend(::typeof(qr_compact!), A::AbstractMatrix; kwargs...)
18-
return default_qr_backend(A; kwargs...)
25+
function default_algorithm(::typeof(qr_compact!), A::AbstractMatrix; kwargs...)
26+
return default_qr_algorithm(A; kwargs...)
1927
end
2028

21-
function default_qr_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
22-
return LAPACKBackend()
29+
function default_qr_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
30+
return LAPACK_HouseholderQR(; kwargs...)
2331
end
2432

25-
function check_qr_full_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
33+
function check_qr_full_input(A::AbstractMatrix, QR)
2634
m, n = size(A)
27-
size(Q) == (m, m) ||
35+
Q, R = QR
36+
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, m)) ||
2837
throw(DimensionMismatch("Full unitary matrix `Q` must be square with equal number of rows as A"))
29-
isempty(R) || size(R) == (m, n) ||
38+
(R isa AbstractMatrix && eltype(R) == eltype(A) && (isempty(R) || size(R) == (m, n))) ||
3039
throw(DimensionMismatch("Upper triangular matrix `R` must have size equal to A"))
3140
return nothing
3241
end
33-
function check_qr_compact_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
42+
function check_qr_compact_input(A::AbstractMatrix, QR)
3443
m, n = size(A)
3544
if n <= m
36-
size(Q) == (m, n) ||
45+
Q, R = QR
46+
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) ||
3747
throw(DimensionMismatch("Isometric `Q` must have size equal to A"))
38-
isempty(R) || size(R) == (n, n) ||
48+
(R isa AbstractMatrix && eltype(R) == eltype(A) &&
49+
(isempty(R) || size(R) == (n, n))) ||
3950
throw(DimensionMismatch("Upper triangular matrix `R` must be square with equal number of columns as A"))
4051
else
41-
check_qr_full_input(A, Q, R)
52+
check_qr_full_input(A, QR)
4253
end
4354
end
4455

45-
function qr_full!(A::AbstractMatrix,
46-
Q::AbstractMatrix,
47-
R::AbstractMatrix,
48-
backend::LAPACKBackend;
49-
positive=false,
50-
pivoted=false,
51-
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
52-
check_qr_full_input(A, Q, R)
53-
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
56+
function qr_full!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR)
57+
check_qr_full_input(A, QR)
58+
Q, R = QR
59+
_lapack_qr!(A, Q, R; alg.kwargs...)
5460
return Q, R
5561
end
56-
57-
function qr_compact!(A::AbstractMatrix,
58-
Q::AbstractMatrix,
59-
R::AbstractMatrix,
60-
backend::LAPACKBackend;
61-
positive=false,
62-
pivoted=false,
63-
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
64-
check_qr_compact_input(A, Q, R)
65-
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
62+
function qr_compact!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR)
63+
check_qr_compact_input(A, QR)
64+
Q, R = QR
65+
_lapack_qr!(A, Q, R; alg.kwargs...)
6666
return Q, R
6767
end
6868

69-
function _unsafe_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
69+
function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
7070
positive=false,
7171
pivoted=false,
7272
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))

0 commit comments

Comments
 (0)