-
Notifications
You must be signed in to change notification settings - Fork 5
add support for BigFloats via a new extension #87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
93f37f5
5e08f7a
92534bc
ddaf055
88b9902
3eee758
f8d331e
03caa56
faa0921
9ed02bd
c3e069e
7af4d52
c6ba4e4
fdb1222
49404bf
837a7a4
78f92f7
fb1994f
3a36446
e6f64b2
4323b17
f99112b
89fd42b
35af44d
4fe0fe1
c08dad4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| module MatrixAlgebraKitGenericLinearAlgebraExt | ||
|
|
||
| using MatrixAlgebraKit | ||
| using MatrixAlgebraKit: sign_safe, check_input | ||
| using GenericLinearAlgebra: svd, svdvals!, eigen, eigvals, Hermitian, qr | ||
| using LinearAlgebra: I, Diagonal | ||
|
|
||
| function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} | ||
| return GLA_svd_QRIteration() | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix{T}, USVᴴ, alg::GLA_svd_QRIteration) where {T} | ||
| check_input(svd_compact!, A, USVᴴ, alg) | ||
| U, S, Vᴴ = USVᴴ | ||
| Ũ, S̃, Ṽ = svd(A) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| copyto!(U, Ũ) | ||
| copyto!(S, Diagonal(S̃)) | ||
| copyto!(Vᴴ, Ṽ') # conjugation to account for difference in convention | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return U, S, Vᴴ | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.svd_full!(A::AbstractMatrix{T}, USVᴴ, alg::GLA_svd_QRIteration) where {T} | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| check_input(svd_full!, A, USVᴴ, alg) | ||
| U, S, Vᴴ = USVᴴ | ||
| m, n = size(A) | ||
| minmn = min(m, n) | ||
| if minmn == 0 | ||
| MatrixAlgebraKit.one!(U) | ||
| MatrixAlgebraKit.zero!(S) | ||
| MatrixAlgebraKit.one!(Vᴴ) | ||
| return USVᴴ | ||
| end | ||
| S̃ = fill!(S, zero(T)) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| U_compact, S_compact, V_compact = svd(A) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| S̃[1:minmn, 1:minmn] .= Diagonal(S_compact) | ||
| copyto!(S, S̃) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| U = _gram_schmidt!(U, U_compact) | ||
| Vᴴ = _gram_schmidt!(Vᴴ, V_compact; adjoint = true) | ||
|
|
||
| return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, m, n) | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix{T}, S, alg::GLA_svd_QRIteration) where {T} | ||
| check_input(svd_vals!, A, S, alg) | ||
| S̃ = svdvals!(A) | ||
| copyto!(S, S̃) | ||
| return S | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} | ||
| return GLA_eigh_Francis(; kwargs...) | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix{T}, DV, alg::GLA_eigh_Francis) where {T} | ||
| check_input(eigh_full!, A, DV, alg) | ||
| D, V = DV | ||
| eigval, eigvec = eigen(Hermitian(A); sortby = real) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| copyto!(D, Diagonal(eigval)) | ||
| copyto!(V, eigvec) | ||
| return D, V | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix{T}, D, alg::GLA_eigh_Francis) where {T} | ||
| check_input(eigh_vals!, A, D, alg) | ||
| D = eigvals(A; sortby = real) | ||
| return real.(D) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
|
|
||
| function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} | ||
| return GLA_QR_Householder(; kwargs...) | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_QR_Householder) | ||
| Q, R = QR | ||
| m, n = size(A) | ||
| minmn = min(m, n) | ||
| computeR = length(R) > 0 | ||
|
|
||
| Q_zero = zeros(eltype(Q), (m, minmn)) | ||
| R_zero = zeros(eltype(R), (minmn, n)) | ||
| Q_compact, R_compact = _gla_householder_qr!(A, Q_zero, R_zero; alg.kwargs...) | ||
| Q = _gram_schmidt!(Q, Q_compact[:, 1:min(m, n)]) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if computeR | ||
| R = fill!(R, zero(eltype(R))) | ||
| R[1:minmn, 1:n] .= R_compact | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
| return Q, R | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_QR_Householder) | ||
| check_input(qr_compact!, A, QR, alg) | ||
| Q, R = QR | ||
| Q, R = _gla_householder_qr!(A, Q, R; alg.kwargs...) | ||
| return Q, R | ||
| end | ||
|
|
||
| function _gla_householder_qr!(A::AbstractMatrix{T}, Q, R; positive = false, blocksize = 1, pivoted = false) where {T} | ||
| pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_QR_Householder.")) | ||
| (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_QR_Householder.")) | ||
|
|
||
| m, n = size(A) | ||
| k = min(m, n) | ||
| computeR = length(R) > 0 | ||
| Q̃, R̃ = qr(A) | ||
|
||
| Q̃ = convert(Array, Q̃) | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if positive | ||
| @inbounds for j in 1:k | ||
| s = sign_safe(R̃[j, j]) | ||
| @simd for i in 1:m | ||
| Q̃[i, j] *= s | ||
| end | ||
| end | ||
| end | ||
| copyto!(Q, Q̃) | ||
| if computeR | ||
| if positive | ||
| @inbounds for j in n:-1:1 | ||
| @simd for i in 1:min(k, j) | ||
| R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i])) | ||
| end | ||
| end | ||
| end | ||
| copyto!(R, R̃) | ||
| end | ||
| return Q, R | ||
| end | ||
|
|
||
| function _gram_schmidt(Q_compact) | ||
| m, minmn = size(Q_compact) | ||
| if minmn >= m | ||
| return Q_compact | ||
| end | ||
| Q = zeros(eltype(Q_compact), (m, m)) | ||
| Q[:, 1:minmn] .= Q_compact | ||
| for j in (minmn + 1):m | ||
| v = rand(eltype(Q), m) | ||
| for i in 1:(j - 1) | ||
| r = sum([v[k] * conj(Q[k, i])] for k in 1:size(v)[1])[1] | ||
| v .= v .- r * Q[:, i] | ||
| end | ||
| Q[:, j] = v ./ MatrixAlgebraKit.norm(v) | ||
| end | ||
| return Q | ||
| end | ||
|
|
||
| function _gram_schmidt!(Q, Q_compact; adjoint = false) | ||
| Q̃ = _gram_schmidt(Q_compact) | ||
| if adjoint | ||
| copyto!(Q, Q̃') | ||
| else | ||
| copyto!(Q, Q̃) | ||
| end | ||
| return Q | ||
| end | ||
sanderdemeyer marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} | ||
| return MatrixAlgebraKit.LQViaTransposedQR(GLA_QR_Householder(; kwargs...)) | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| module MatrixAlgebraKitGenericSchurExt | ||
|
|
||
| using MatrixAlgebraKit | ||
| using MatrixAlgebraKit: check_input | ||
| using LinearAlgebra: Diagonal | ||
| using GenericSchur | ||
|
|
||
| function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} | ||
| return GS_eig_Francis(; kwargs...) | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.eig_full!(A::AbstractMatrix{T}, DV, alg::GS_eig_Francis) where {T} | ||
| check_input(eig_full!, A, DV, alg) | ||
| D, V = DV | ||
| D̃, Ṽ = GenericSchur.eigen!(A) | ||
| copyto!(D, Diagonal(D̃)) | ||
| copyto!(V, Ṽ) | ||
| return D, V | ||
| end | ||
|
|
||
| function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix{T}, D, alg::GS_eig_Francis) where {T} | ||
| check_input(eig_vals!, A, D, alg) | ||
| eigval = GenericSchur.eigvals!(A) | ||
| copyto!(D, eigval) | ||
| return D | ||
| end | ||
|
|
||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| using MatrixAlgebraKit | ||
| using Test | ||
| using TestExtras | ||
| using StableRNGs | ||
| using LinearAlgebra: Diagonal | ||
| using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm | ||
| using GenericSchur | ||
|
|
||
| const eltypes = (BigFloat, Complex{BigFloat}) | ||
|
|
||
| @testset "eig_full! for T = $T" for T in eltypes | ||
| rng = StableRNG(123) | ||
| m = 24 | ||
| alg = GS_eig_Francis() | ||
| A = randn(rng, T, m, m) | ||
| Tc = complex(T) | ||
|
|
||
| D, V = @constinferred eig_full(A; alg = ($alg)) | ||
| @test eltype(D) == eltype(V) == Tc | ||
| @test A * V ≈ V * D | ||
|
|
||
| alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) | ||
|
|
||
| Ac = similar(A) | ||
| D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) | ||
| @test D2 === D | ||
| @test V2 === V | ||
| @test A * V ≈ V * D | ||
|
|
||
| Dc = @constinferred eig_vals(A, alg′) | ||
| @test eltype(Dc) == Tc | ||
| @test D ≈ Diagonal(Dc) | ||
| end | ||
|
|
||
| @testset "eig_trunc! for T = $T" for T in eltypes | ||
| rng = StableRNG(123) | ||
| m = 6 | ||
| alg = GS_eig_Francis() | ||
| A = randn(rng, T, m, m) | ||
| A *= A' # TODO: deal with eigenvalue ordering etc | ||
| # eigenvalues are sorted by ascending real component... | ||
| D₀ = sort!(eig_vals(A); by = abs, rev = true) | ||
| rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) | ||
| r = length(D₀) - rmin | ||
| atol = sqrt(eps(real(T))) | ||
|
|
||
| D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) | ||
| D1base, V1base = @constinferred eig_full(A; alg) | ||
|
|
||
| @test length(diagview(D1)) == r | ||
| @test A * V1 ≈ V1 * D1 | ||
| @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol | ||
|
|
||
| s = 1 + sqrt(eps(real(T))) | ||
| trunc = trunctol(; atol = s * abs(D₀[r + 1])) | ||
| D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) | ||
| @test length(diagview(D2)) == r | ||
| @test A * V2 ≈ V2 * D2 | ||
| @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol | ||
|
|
||
| s = 1 - sqrt(eps(real(T))) | ||
| trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) | ||
| D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) | ||
| @test length(diagview(D3)) == r | ||
| @test A * V3 ≈ V3 * D3 | ||
| @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol | ||
|
|
||
| # trunctol keeps order, truncrank might not | ||
| # test for same subspace | ||
| @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 | ||
| @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 | ||
| @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3 | ||
| @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1 | ||
| end | ||
|
|
||
| @testset "eig_trunc! specify truncation algorithm T = $T" for T in eltypes | ||
| rng = StableRNG(123) | ||
| m = 4 | ||
| atol = sqrt(eps(real(T))) | ||
| V = randn(rng, T, m, m) | ||
| D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) | ||
| A = V * D * inv(V) | ||
| alg = TruncatedAlgorithm(GS_eig_Francis(), truncrank(2)) | ||
| D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) | ||
| @test diagview(D2) ≈ diagview(D)[1:2] | ||
| @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol | ||
| @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) | ||
|
|
||
| alg = TruncatedAlgorithm(GS_eig_Francis(), truncerror(; atol = 0.2, p = 1)) | ||
| D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) | ||
| @test diagview(D3) ≈ diagview(D)[1:2] | ||
| @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol | ||
| end | ||
|
|
||
| @testset "eig for Diagonal{$T}" for T in eltypes | ||
| rng = StableRNG(123) | ||
| m = 24 | ||
| Ad = randn(rng, T, m) | ||
| A = Diagonal(Ad) | ||
| atol = sqrt(eps(real(T))) | ||
|
|
||
| D, V = @constinferred eig_full(A) | ||
| @test D isa Diagonal{T} && size(D) == size(A) | ||
| @test V isa Diagonal{T} && size(V) == size(A) | ||
| @test A * V ≈ V * D | ||
|
|
||
| D2 = @constinferred eig_vals(A) | ||
| @test D2 isa AbstractVector{T} && length(D2) == m | ||
| @test diagview(D) ≈ D2 | ||
|
|
||
| A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) | ||
| alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) | ||
| D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) | ||
| @test diagview(D2) ≈ diagview(A2)[1:2] | ||
| @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol | ||
| end |
Uh oh!
There was an error while loading. Please reload this page.