Skip to content

Commit 5d36c98

Browse files
sanderdemeyerlkdvosJutho
authored
add support for BigFloats via a new extension (#87)
* add GenericExt Add extension using GenericLinearAlgebra and GenericSchur to be able to deal with `BigFloat`s * include weak dependencies * Update ext/MatrixAlgebraKitGenericExt.jl Co-authored-by: Lukas Devos <[email protected]> * Update ext/MatrixAlgebraKitGenericExt.jl Co-authored-by: Lukas Devos <[email protected]> * comments from Lukas * remove `BigFloat_LQ_Householder` * Change struct names and relax type restrictions * Split extensions * fix copies * fix type instability use Hermitian() small cleanup * Name change From GLA to GS, since eig using GenericSchur and not GenericLinearAlgebra * Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl Co-authored-by: Jutho <[email protected]> * Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl Co-authored-by: Jutho <[email protected]> * Update ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl Co-authored-by: Jutho <[email protected]> * Resolve comments Name change some copying issues resolved * change folder names * add docs and namechange GS and GLA algorithms are now defined in the docs `GLA_QR_Householder` is now `GLA_HouseholderQR` to be consistent with the LAPACK algorithms * Remove unnecessary type parameters * simplify SVD and remove allocations * simplify Eigh and remove allocations * simplify QR * simplify Eig and remove allocations * switch to `lmul!` * docstring improvements * move Diagonal tests --------- Co-authored-by: Lukas Devos <[email protected]> Co-authored-by: Jutho <[email protected]>
1 parent c52614b commit 5d36c98

File tree

16 files changed

+819
-17
lines changed

16 files changed

+819
-17
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13+
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
14+
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
1315

1416
[extensions]
1517
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
1618
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1719
MatrixAlgebraKitCUDAExt = "CUDA"
20+
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
21+
MatrixAlgebraKitGenericSchurExt = "GenericSchur"
1822

1923
[compat]
2024
AMDGPU = "2"
2125
Aqua = "0.6, 0.7, 0.8"
2226
ChainRulesCore = "1"
2327
ChainRulesTestUtils = "1"
2428
CUDA = "5"
29+
GenericLinearAlgebra = "0.3.19"
30+
GenericSchur = "0.5.6"
2531
JET = "0.9, 0.10"
2632
LinearAlgebra = "1"
2733
SafeTestsets = "0.1"
@@ -42,4 +48,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
4248
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4349

4450
[targets]
45-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
51+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
module MatrixAlgebraKitGenericLinearAlgebraExt
2+
3+
using MatrixAlgebraKit
4+
using MatrixAlgebraKit: sign_safe, check_input, diagview
5+
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
6+
using LinearAlgebra: I, Diagonal, lmul!
7+
8+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
9+
return GLA_QRIteration()
10+
end
11+
12+
for f! in (:svd_compact!, :svd_full!, :svd_vals!)
13+
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
14+
end
15+
16+
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration)
17+
F = svd!(A)
18+
U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt
19+
return MatrixAlgebraKit.gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...)
20+
end
21+
22+
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration)
23+
F = svd!(A; full = true)
24+
U, Vᴴ = F.U, F.Vt
25+
S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1))))
26+
diagview(S) .= F.S
27+
return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, size(A)...)
28+
end
29+
30+
function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration)
31+
return svdvals!(A)
32+
end
33+
34+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
35+
return GLA_QRIteration(; kwargs...)
36+
end
37+
38+
for f! in (:eigh_full!, :eigh_vals!)
39+
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
40+
end
41+
42+
function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
43+
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
44+
return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)}
45+
end
46+
47+
function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
48+
return eigvals!(Hermitian(A); sortby = real)
49+
end
50+
51+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
52+
return GLA_HouseholderQR(; kwargs...)
53+
end
54+
55+
function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
56+
check_input(qr_full!, A, QR, alg)
57+
Q, R = QR
58+
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
59+
end
60+
61+
function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
62+
check_input(qr_compact!, A, QR, alg)
63+
Q, R = QR
64+
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
65+
end
66+
67+
function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksize = 1, pivoted = false)
68+
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
69+
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
70+
71+
m, n = size(A)
72+
k = min(m, n)
73+
Q̃, R̃ = qr!(A)
74+
lmul!(Q̃, MatrixAlgebraKit.one!(Q))
75+
76+
if positive
77+
@inbounds for j in 1:k
78+
s = sign_safe(R̃[j, j])
79+
@simd for i in 1:m
80+
Q[i, j] *= s
81+
end
82+
end
83+
end
84+
85+
computeR = length(R) > 0
86+
if computeR
87+
if positive
88+
@inbounds for j in n:-1:1
89+
@simd for i in 1:min(k, j)
90+
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
91+
end
92+
@simd for i in (min(k, j) + 1):size(R, 1)
93+
R[i, j] = zero(eltype(R))
94+
end
95+
end
96+
else
97+
R[1:k, :] .=
98+
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
99+
end
100+
end
101+
return Q, R
102+
end
103+
104+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
105+
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
106+
end
107+
108+
end
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module MatrixAlgebraKitGenericSchurExt
2+
3+
using MatrixAlgebraKit
4+
using MatrixAlgebraKit: check_input
5+
using LinearAlgebra: Diagonal
6+
using GenericSchur
7+
8+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
9+
return GS_QRIteration(; kwargs...)
10+
end
11+
12+
for f! in (:eig_full!, :eig_vals!)
13+
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing
14+
end
15+
16+
function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
17+
D, V = GenericSchur.eigen!(A)
18+
return Diagonal(D), V
19+
end
20+
21+
function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
22+
return GenericSchur.eigvals!(A)
23+
end
24+
25+
end

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null!
3333
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
3434
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
3535
LAPACK_DivideAndConquer, LAPACK_Jacobi
36+
export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration
3637
export LQViaTransposedQR
3738
export PolarViaSVD, PolarNewton
3839
export DiagonalAlgorithm

src/interface/decompositions.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Algorithm type to denote the standard LAPACK algorithm for computing the QR deco
1616
a matrix using Householder reflectors. The specific LAPACK function can be controlled using
1717
the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With
1818
`blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen
19-
if `pivoted == true`. The keyword `positive=true` can be used to ensure that the diagonal
19+
if `pivoted == true`. The keyword `positive = true` can be used to ensure that the diagonal
2020
elements of `R` are non-negative.
2121
"""
2222
@algdef LAPACK_HouseholderQR
@@ -27,11 +27,21 @@ elements of `R` are non-negative.
2727
Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of
2828
a matrix using Householder reflectors. The specific LAPACK function can be controlled using
2929
the keyword arugments, i.e. `?gelqt` will be chosen if `blocksize > 1` or `?gelqf` will be
30-
chosen if `blocksize == 1`. The keyword `positive=true` can be used to ensure that the diagonal
30+
chosen if `blocksize == 1`. The keyword `positive = true` can be used to ensure that the diagonal
3131
elements of `L` are non-negative.
3232
"""
3333
@algdef LAPACK_HouseholderLQ
3434

35+
"""
36+
GLA_HouseholderQR(; positive = false)
37+
38+
Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the QR decomposition
39+
of a matrix using Householder reflectors. Currently, only `blocksize = 1` and `pivoted == false`
40+
are supported. The keyword `positive = true` can be used to ensure that the diagonal elements
41+
of `R` are non-negative.
42+
"""
43+
@algdef GLA_HouseholderQR
44+
3545
# TODO:
3646
@algdef LAPACK_HouseholderQL
3747
@algdef LAPACK_HouseholderRQ
@@ -56,6 +66,14 @@ eigenvalue decomposition of a matrix.
5666

5767
const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert}
5868

69+
"""
70+
GS_QRIteration()
71+
72+
Algorithm type to denote the GenericSchur.jl implementation for computing the
73+
eigenvalue decomposition of a non-Hermitian matrix.
74+
"""
75+
@algdef GS_QRIteration
76+
5977
# Hermitian Eigenvalue Decomposition
6078
# ----------------------------------
6179
"""
@@ -100,6 +118,15 @@ const LAPACK_EighAlgorithm = Union{
100118
LAPACK_MultipleRelativelyRobustRepresentations,
101119
}
102120

121+
"""
122+
GLA_QRIteration()
123+
124+
Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the
125+
eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of
126+
a general matrix.
127+
"""
128+
@algdef GLA_QRIteration
129+
103130
# Singular Value Decomposition
104131
# ----------------------------
105132
"""

test/eig.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using StableRNGs
55
using LinearAlgebra: Diagonal
66
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
77

8-
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
910

1011
@testset "eig_full! for T = $T" for T in BLASFloats
1112
rng = StableRNG(123)
@@ -91,7 +92,7 @@ end
9192
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
9293
end
9394

94-
@testset "eig for Diagonal{$T}" for T in BLASFloats
95+
@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
9596
rng = StableRNG(123)
9697
m = 54
9798
Ad = randn(rng, T, m)

test/eigh.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, I
66
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
77

8-
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
8+
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
910

1011
@testset "eigh_full! for T = $T" for T in BLASFloats
1112
rng = StableRNG(123)
@@ -100,7 +101,7 @@ end
100101
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
101102
end
102103

103-
@testset "eigh for Diagonal{$T}" for T in BLASFloats
104+
@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
104105
rng = StableRNG(123)
105106
m = 54
106107
Ad = randn(rng, T, m)

test/genericlinearalgebra/eigh.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, I
6+
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
7+
using GenericLinearAlgebra
8+
9+
const eltypes = (BigFloat, Complex{BigFloat})
10+
11+
@testset "eigh_full! for T = $T" for T in eltypes
12+
rng = StableRNG(123)
13+
m = 54
14+
alg = GLA_QRIteration()
15+
16+
A = randn(rng, T, m, m)
17+
A = (A + A') / 2
18+
19+
D, V = @constinferred eigh_full(A; alg)
20+
@test A * V V * D
21+
@test isunitary(V)
22+
@test all(isreal, D)
23+
24+
D2, V2 = eigh_full!(copy(A), (D, V), alg)
25+
@test D2 D
26+
@test V2 V
27+
28+
D3 = @constinferred eigh_vals(A, alg)
29+
@test D Diagonal(D3)
30+
end
31+
32+
@testset "eigh_trunc! for T = $T" for T in eltypes
33+
rng = StableRNG(123)
34+
m = 54
35+
alg = GLA_QRIteration()
36+
A = randn(rng, T, m, m)
37+
A = A * A'
38+
A = (A + A') / 2
39+
Ac = similar(A)
40+
D₀ = reverse(eigh_vals(A))
41+
42+
r = m - 2
43+
s = 1 + sqrt(eps(real(T)))
44+
atol = sqrt(eps(real(T)))
45+
46+
D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r))
47+
Dfull, Vfull = eigh_full(A; alg)
48+
@test length(diagview(D1)) == r
49+
@test isisometric(V1)
50+
@test A * V1 V1 * D1
51+
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') D₀[r + 1]
52+
@test ϵ1 norm(view(D₀, (r + 1):m)) atol = atol
53+
54+
trunc = trunctol(; atol = s * D₀[r + 1])
55+
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
56+
@test length(diagview(D2)) == r
57+
@test isisometric(V2)
58+
@test A * V2 V2 * D2
59+
@test ϵ2 norm(view(D₀, (r + 1):m)) atol = atol
60+
61+
s = 1 - sqrt(eps(real(T)))
62+
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
63+
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
64+
@test length(diagview(D3)) == r
65+
@test A * V3 V3 * D3
66+
@test ϵ3 norm(view(D₀, (r + 1):m)) atol = atol
67+
68+
# test for same subspace
69+
@test V1 * (V1' * V2) V2
70+
@test V2 * (V2' * V1) V1
71+
@test V1 * (V1' * V3) V3
72+
@test V3 * (V3' * V1) V1
73+
end
74+
75+
@testset "eigh_trunc! specify truncation algorithm T = $T" for T in eltypes
76+
rng = StableRNG(123)
77+
m = 4
78+
atol = sqrt(eps(real(T)))
79+
V = qr_compact(randn(rng, T, m, m))[1]
80+
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
81+
A = V * D * V'
82+
A = (A + A') / 2
83+
alg = TruncatedAlgorithm(GLA_QRIteration(), truncrank(2))
84+
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
85+
@test diagview(D2) diagview(D)[1:2]
86+
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
87+
@test ϵ2 norm(diagview(D)[3:4]) atol = atol
88+
89+
alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2))
90+
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
91+
@test diagview(D3) diagview(D)[1:2]
92+
@test ϵ3 norm(diagview(D)[3:4]) atol = atol
93+
end

0 commit comments

Comments
 (0)