Skip to content

Commit 2d5b83a

Browse files
committed
some progress with factorisations and tests
1 parent 3d66e5d commit 2d5b83a

File tree

13 files changed

+582
-155
lines changed

13 files changed

+582
-155
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ version = "0.1.0"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

99
[compat]
10+
Aqua = "0.6, 0.7, 0.8"
11+
JET = "0.9"
1012
LinearAlgebra = "1"
13+
Test = "1"
14+
TestExtras = "0.2,0.3"
1115
julia = "1.10"
1216

1317
[extras]
1418
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1519
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1620
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21+
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1722

1823
[targets]
19-
test = ["Aqua", "JET", "Test"]
24+
test = ["Aqua", "JET", "Test", "TestExtras"]

src/MatrixAlgebraKit.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@ module MatrixAlgebraKit
33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!
55

6+
export qr_compact!, qr_full!
7+
export eigh_full!, eigh_vals!, eigh_trunc!
8+
export svd_compact!, svd_full!, svd_vals!, svd_trunc!
9+
610
include("auxiliary.jl")
711
include("backend.jl")
812
include("yalapack.jl")
13+
include("qr.jl")
914
include("svd.jl")
1015
include("eigh.jl")
1116

src/backend.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
struct LAPACKBackend end
1+
struct LAPACKBackend end

src/eigh.jl

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
# `eigh!`` is a simple wrapper for `eigh_full!`
2-
function eigh!(A::AbstractMatrix,
3-
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)),
4-
V::AbstractMatrix=similar(A, size(A));
5-
kwargs...)
6-
return eigh_full!(A, D, V; kwargs...)
1+
# TODO: do not export but mark as public ?
2+
function eigh!(A::AbstractMatrix, args...; kwargs...)
3+
return eigh_full!(A, args...; kwargs...)
74
end
85

96
function eigh_full!(A::AbstractMatrix,
@@ -12,6 +9,11 @@ function eigh_full!(A::AbstractMatrix,
129
kwargs...)
1310
return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs...); kwargs...)
1411
end
12+
function eigh_vals!(A::AbstractMatrix,
13+
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1));
14+
kwargs...)
15+
return eigh_vals!(A, D, default_backend(eigh_vals!, A; kwargs...); kwargs...)
16+
end
1517
function eigh_trunc!(A::AbstractMatrix;
1618
kwargs...)
1719
return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs...); kwargs...)
@@ -20,6 +22,9 @@ end
2022
function default_backend(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
2123
return default_eigh_backend(A; kwargs...)
2224
end
25+
function default_backend(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...)
26+
return default_eigh_backend(A; kwargs...)
27+
end
2328
function default_backend(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
2429
return default_eigh_backend(A; kwargs...)
2530
end
@@ -37,6 +42,13 @@ function check_eigh_full_input(A, D, V)
3742
throw(DimensionMismatch("Eigenvector matrix `V` must have size equal to A"))
3843
return nothing
3944
end
45+
function check_eigh_vals_input(A, D)
46+
m, n = size(A)
47+
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
48+
size(D) == (n,) ||
49+
throw(DimensionMismatch("Eigenvalue vector `D` must have length equal to size(A, 1)"))
50+
return nothing
51+
end
4052

4153
@static if VERSION >= v"1.12-DEV.0"
4254
const RobustRepresentations = LinearAlgebra.RobustRepresentations
@@ -58,17 +70,37 @@ function eigh_full!(A::AbstractMatrix,
5870
elseif alg == LinearAlgebra.QRIteration()
5971
YALAPACK.heev!(A, D, V; kwargs...)
6072
else
61-
throw(ArgumentError("Unknown algorithm $alg"))
73+
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
6274
end
6375
return D, V
6476
end
6577

66-
# for eigh_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes
78+
function eigh_vals!(A::AbstractMatrix,
79+
D::AbstractVector,
80+
backend::LAPACKBackend;
81+
alg=RobustRepresentations(),
82+
kwargs...)
83+
check_eigh_vals_input(A, D)
84+
V = similar(A, (size(A, 1), 0))
85+
if alg == RobustRepresentations()
86+
YALAPACK.heevr!(A, D, V; kwargs...)
87+
elseif alg == LinearAlgebra.DivideAndConquer()
88+
YALAPACK.heevd!(A, D, V; kwargs...)
89+
elseif alg == LinearAlgebra.QRIteration()
90+
YALAPACK.heev!(A, D, V; kwargs...)
91+
else
92+
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
93+
end
94+
return D
95+
end
96+
97+
# for eigh_trunc!, it doesn't make sense to preallocate D and V as we don't know their sizes
6798
function eigh_trunc!(A::AbstractMatrix,
6899
backend::LAPACKBackend;
69100
alg=RobustRepresentations(),
70-
tol=zero(real(eltype(A))),
71-
rank=min(size(A)...),
101+
atol=zero(real(eltype(A))),
102+
rtol=zero(real(eltype(A))),
103+
rank=size(A, 1),
72104
kwargs...)
73105
if alg == RobustRepresentations()
74106
D, V = YALAPACK.heevr!(A; kwargs...)
@@ -77,12 +109,15 @@ function eigh_trunc!(A::AbstractMatrix,
77109
elseif alg == LinearAlgebra.QRIteration()
78110
D, V = YALAPACK.heev!(A; kwargs...)
79111
else
80-
throw(ArgumentError("Unknown algorithm $alg"))
112+
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
81113
end
82-
# eigenvalues are sorted in ascending order; do we assume that they are positive?
114+
# eigenvalues are sorted in ascending order
115+
# TODO: do we assume that they are positive, or should we check for this?
116+
# or do we want to truncate based on absolute value and thus sort differently?
83117
n = length(D)
84-
s = max(n - rank, findfirst(>=(tol * D[end]), S))
118+
tol = convert(eltype(D), max(atol, rtol * D[n]))
119+
s = max(n - rank + 1, findfirst(>=(tol), D))
85120
# TODO: do we want views here, such that we do not need extra allocations if we later
86121
# copy them into other storage
87122
return D[n:-1:s], V[:, n:-1:s]
88-
end
123+
end

src/matrixfunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/qr.jl

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ function qr_full!(A::AbstractMatrix,
22
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
33
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
44
kwargs...)
5-
return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...))
5+
return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...); kwargs...)
66
end
77
function qr_compact!(A::AbstractMatrix,
88
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
99
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
1010
kwargs...)
11-
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...))
11+
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...); kwargs...)
1212
end
1313

1414
function default_backend(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
@@ -20,4 +20,113 @@ end
2020

2121
function default_qr_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
2222
return LAPACKBackend()
23-
end
23+
end
24+
25+
function check_qr_full_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
26+
m, n = size(A)
27+
size(Q) == (m, m) ||
28+
throw(DimensionMismatch("Full unitary matrix `Q` must be square with equal number of rows as A"))
29+
isempty(R) || size(R) == (m, n) ||
30+
throw(DimensionMismatch("Upper triangular matrix `R` must have size equal to A"))
31+
return nothing
32+
end
33+
function check_qr_compact_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
34+
m, n = size(A)
35+
if n <= m
36+
size(Q) == (m, n) ||
37+
throw(DimensionMismatch("Isometric `Q` must have size equal to A"))
38+
isempty(R) || size(R) == (n, n) ||
39+
throw(DimensionMismatch("Upper triangular matrix `R` must be square with equal number of columns as A"))
40+
else
41+
check_qr_full_input(A, Q, R)
42+
end
43+
end
44+
45+
function qr_full!(A::AbstractMatrix,
46+
Q::AbstractMatrix,
47+
R::AbstractMatrix,
48+
backend::LAPACKBackend;
49+
positive=false,
50+
pivoted=false,
51+
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
52+
check_qr_full_input(A, Q, R)
53+
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
54+
return Q, R
55+
end
56+
57+
function qr_compact!(A::AbstractMatrix,
58+
Q::AbstractMatrix,
59+
R::AbstractMatrix,
60+
backend::LAPACKBackend;
61+
positive=false,
62+
pivoted=false,
63+
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
64+
check_qr_compact_input(A, Q, R)
65+
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
66+
return Q, R
67+
end
68+
69+
function _unsafe_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
70+
positive=false,
71+
pivoted=false,
72+
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
73+
m, n = size(A)
74+
minmn = min(m, n)
75+
computeR = length(R) > 0
76+
inplaceQ = Q === A
77+
78+
if pivoted && (blocksize > 1)
79+
throw(ArgumentError("LAPACK does not provide a blocked implementation for a pivoted QR decomposition"))
80+
end
81+
if inplaceQ && (computeR || positive || blocksize > 1 || m < n)
82+
throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`"))
83+
end
84+
85+
if blocksize > 1
86+
nb = min(minmn, blocksize)
87+
if computeR # first use R as space for T
88+
A, T = YALAPACK.geqrt!(A, view(R, 1:nb, 1:minmn))
89+
else
90+
A, T = YALAPACK.geqrt!(A, similar(A, nb, minmn))
91+
end
92+
Q = YALAPACK.gemqrt!('L', 'N', A, T, one!(Q))
93+
else
94+
if pivoted
95+
A, τ, jpvt = YALAPACK.geqp3!(A)
96+
else
97+
A, τ = YALAPACK.geqrf!(A)
98+
end
99+
if inplaceQ
100+
Q = YALAPACK.orgqr!(A, τ)
101+
else
102+
Q = YALAPACK.ormqr!('L', 'N', A, τ, one!(Q))
103+
end
104+
end
105+
106+
if positive # already fix Q even if we do not need R
107+
@inbounds for j in 1:minmn
108+
s = safesign(A[j, j])
109+
@simd for i in 1:m
110+
Q[i, j] *= s
111+
end
112+
end
113+
end
114+
115+
if computeR
116+
= triu!(view(A, axes(R)...))
117+
if positive
118+
@inbounds for j in n:-1:1
119+
@simd for i in 1:min(minmn, j)
120+
R̃[i, j] = R̃[i, j] * conj(safesign(R̃[i, i]))
121+
end
122+
end
123+
end
124+
if !pivoted
125+
copyto!(R, R̃)
126+
else
127+
# probably very inefficient in terms of memory access
128+
copyto!(view(R, :, jpvt), R̃)
129+
end
130+
end
131+
return Q, R
132+
end

src/svd.jl

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# TODO: do not export but mark as public ?
2+
function svd!(A::AbstractMatrix, args...; kwargs...)
3+
return svd_compact!(A, args...; kwargs...)
4+
end
5+
16
function svd_full!(A::AbstractMatrix,
27
U::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
38
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)),
@@ -12,6 +17,12 @@ function svd_compact!(A::AbstractMatrix,
1217
kwargs...)
1318
return svd_compact!(A, U, S, Vᴴ, default_backend(svd_compact!, A; kwargs...); kwargs...)
1419
end
20+
function svd_vals!(A::AbstractMatrix,
21+
S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),));
22+
kwargs...)
23+
return svd_vals!(A, S, default_backend(svd_vals!, A; kwargs...); kwargs...)
24+
end
25+
1526
function svd_trunc!(A::AbstractMatrix;
1627
kwargs...)
1728
return svd_trunc!(A, default_backend(svd_trunc!, A; kwargs...); kwargs...)
@@ -23,6 +34,9 @@ end
2334
function default_backend(::typeof(svd_compact!), A::AbstractMatrix; kwargs...)
2435
return default_svd_backend(A; kwargs...)
2536
end
37+
function default_backend(::typeof(svd_vals!), A::AbstractMatrix; kwargs...)
38+
return default_svd_backend(A; kwargs...)
39+
end
2640
function default_backend(::typeof(svd_trunc!), A::AbstractMatrix; kwargs...)
2741
return default_svd_backend(A; kwargs...)
2842
end
@@ -53,6 +67,13 @@ function check_svd_compact_input(A, U, S, Vᴴ)
5367
throw(DimensionMismatch("`svd_compact!` requires vector S of length min(size(A)..."))
5468
return nothing
5569
end
70+
function check_svd_vals_input(A, S)
71+
m, n = size(A)
72+
minmn = min(m, n)
73+
size(S) == (minmn,) ||
74+
throw(DimensionMismatch("`svd_vals!` requires vector S of length min(size(A)..."))
75+
return nothing
76+
end
5677

5778
function svd_full!(A::AbstractMatrix,
5879
U::AbstractMatrix,
@@ -66,7 +87,7 @@ function svd_full!(A::AbstractMatrix,
6687
elseif alg == LinearAlgebra.QRIteration()
6788
YALAPACK.gesvd!(A, S, U, Vᴴ)
6889
else
69-
throw(ArgumentError("Unknown algorithm $alg"))
90+
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
7091
end
7192
return U, S, Vᴴ
7293
end
@@ -82,26 +103,44 @@ function svd_compact!(A::AbstractMatrix,
82103
elseif alg == LinearAlgebra.QRIteration()
83104
YALAPACK.gesvd!(A, S, U, Vᴴ)
84105
else
85-
throw(ArgumentError("Unknown algorithm $alg"))
106+
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
86107
end
87108
return U, S, Vᴴ
88109
end
89110

111+
function svd_vals!(A::AbstractMatrix,
112+
S::AbstractVector,
113+
backend::LAPACKBackend;
114+
alg=LinearAlgebra.DivideAndConquer())
115+
check_svd_vals_input(A, S)
116+
m, n = size(A)
117+
if alg == LinearAlgebra.DivideAndConquer()
118+
YALAPACK.gesdd!(A, S, similar(A, m, 0), similar(A, n, 0))
119+
elseif alg == LinearAlgebra.QRIteration()
120+
YALAPACK.gesvd!(A, S, similar(A, m, 0), similar(A, n, 0))
121+
else
122+
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
123+
end
124+
return S
125+
end
126+
90127
# for svd_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes
91128
function svd_trunc!(A::AbstractMatrix,
92129
backend::LAPACKBackend;
93130
alg=LinearAlgebra.DivideAndConquer(),
94-
tol=zero(real(eltype(A))),
131+
atol=zero(real(eltype(A))),
132+
rtol=zero(real(eltype(A))),
95133
rank=min(size(A)...))
96134
if alg == LinearAlgebra.DivideAndConquer()
97135
S, U, Vᴴ = YALAPACK.gesdd!(A)
98136
elseif alg == LinearAlgebra.QRIteration()
99137
S, U, Vᴴ = YALAPACK.gesvd!(A)
100138
else
101-
throw(ArgumentError("Unknown algorithm $alg"))
139+
throw(ArgumentError("Unknown LAPACK singular value algorithm $alg"))
102140
end
103-
r = min(rank, findlast(>=(tol * S[1]), S))
141+
tol = convert(eltype(S), max(atol, rtol * S[1]))
142+
r = min(rank, findlast(>=(tol), S))
104143
# TODO: do we want views here, such that we do not need extra allocations if we later
105144
# copy them into other storage
106145
return U[:, 1:r], S[1:r], Vᴴ[1:r, :]
107-
end
146+
end

0 commit comments

Comments
 (0)