diff --git a/Project.toml b/Project.toml index be0bee32..934c0ceb 100644 --- a/Project.toml +++ b/Project.toml @@ -10,11 +10,15 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" +GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" [extensions] MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" +MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" +MatrixAlgebraKitGenericSchurExt = "GenericSchur" [compat] AMDGPU = "2" @@ -22,6 +26,8 @@ Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" CUDA = "5" +GenericLinearAlgebra = "0.3.19" +GenericSchur = "0.5.6" JET = "0.9, 0.10" LinearAlgebra = "1" SafeTestsets = "0.1" @@ -42,4 +48,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"] diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl new file mode 100644 index 00000000..7b078381 --- /dev/null +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -0,0 +1,108 @@ +module MatrixAlgebraKitGenericLinearAlgebraExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: sign_safe, check_input, diagview +using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! +using LinearAlgebra: I, Diagonal, lmul! + +function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return GLA_QRIteration() +end + +for f! in (:svd_compact!, :svd_full!, :svd_vals!) + @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing +end + +function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration) + F = svd!(A) + U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt + return MatrixAlgebraKit.gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) +end + +function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration) + F = svd!(A; full = true) + U, Vᴴ = F.U, F.Vt + S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1)))) + diagview(S) .= F.S + return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, size(A)...) +end + +function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration) + return svdvals!(A) +end + +function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return GLA_QRIteration(; kwargs...) +end + +for f! in (:eigh_full!, :eigh_vals!) + @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing +end + +function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) + eigval, eigvec = eigen!(Hermitian(A); sortby = real) + return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)} +end + +function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) + return eigvals!(Hermitian(A); sortby = real) +end + +function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return GLA_HouseholderQR(; kwargs...) +end + +function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR) + check_input(qr_full!, A, QR, alg) + Q, R = QR + return _gla_householder_qr!(A, Q, R; alg.kwargs...) +end + +function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR) + check_input(qr_compact!, A, QR, alg) + Q, R = QR + return _gla_householder_qr!(A, Q, R; alg.kwargs...) +end + +function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = false, blocksize = 1, pivoted = false) + pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) + (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) + + m, n = size(A) + k = min(m, n) + Q̃, R̃ = qr!(A) + lmul!(Q̃, MatrixAlgebraKit.one!(Q)) + + if positive + @inbounds for j in 1:k + s = sign_safe(R̃[j, j]) + @simd for i in 1:m + Q[i, j] *= s + end + end + end + + computeR = length(R) > 0 + if computeR + if positive + @inbounds for j in n:-1:1 + @simd for i in 1:min(k, j) + R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i])) + end + @simd for i in (min(k, j) + 1):size(R, 1) + R[i, j] = zero(eltype(R)) + end + end + else + R[1:k, :] .= R̃ + MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :])) + end + end + return Q, R +end + +function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...)) +end + +end diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl new file mode 100644 index 00000000..d278b5c5 --- /dev/null +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -0,0 +1,25 @@ +module MatrixAlgebraKitGenericSchurExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: check_input +using LinearAlgebra: Diagonal +using GenericSchur + +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return GS_QRIteration(; kwargs...) +end + +for f! in (:eig_full!, :eig_vals!) + @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing +end + +function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration) + D, V = GenericSchur.eigen!(A) + return Diagonal(D), V +end + +function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) + return GenericSchur.eigvals!(A) +end + +end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 5211476f..4a846f85 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -33,6 +33,7 @@ export left_orth!, right_orth!, left_null!, right_null! export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, LAPACK_DivideAndConquer, LAPACK_Jacobi +export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton export DiagonalAlgorithm diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index c36a6b02..1bdf1534 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -16,7 +16,7 @@ Algorithm type to denote the standard LAPACK algorithm for computing the QR deco a matrix using Householder reflectors. The specific LAPACK function can be controlled using the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With `blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen -if `pivoted == true`. The keyword `positive=true` can be used to ensure that the diagonal +if `pivoted == true`. The keyword `positive = true` can be used to ensure that the diagonal elements of `R` are non-negative. """ @algdef LAPACK_HouseholderQR @@ -27,11 +27,21 @@ elements of `R` are non-negative. Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of a matrix using Householder reflectors. The specific LAPACK function can be controlled using the keyword arugments, i.e. `?gelqt` will be chosen if `blocksize > 1` or `?gelqf` will be -chosen if `blocksize == 1`. The keyword `positive=true` can be used to ensure that the diagonal +chosen if `blocksize == 1`. The keyword `positive = true` can be used to ensure that the diagonal elements of `L` are non-negative. """ @algdef LAPACK_HouseholderLQ +""" + GLA_HouseholderQR(; positive = false) + +Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the QR decomposition +of a matrix using Householder reflectors. Currently, only `blocksize = 1` and `pivoted == false` +are supported. The keyword `positive = true` can be used to ensure that the diagonal elements +of `R` are non-negative. +""" +@algdef GLA_HouseholderQR + # TODO: @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ @@ -56,6 +66,14 @@ eigenvalue decomposition of a matrix. const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} +""" + GS_QRIteration() + +Algorithm type to denote the GenericSchur.jl implementation for computing the +eigenvalue decomposition of a non-Hermitian matrix. +""" +@algdef GS_QRIteration + # Hermitian Eigenvalue Decomposition # ---------------------------------- """ @@ -100,6 +118,15 @@ const LAPACK_EighAlgorithm = Union{ LAPACK_MultipleRelativelyRobustRepresentations, } +""" + GLA_QRIteration() + +Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the +eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of +a general matrix. +""" +@algdef GLA_QRIteration + # Singular Value Decomposition # ---------------------------- """ diff --git a/test/eig.jl b/test/eig.jl index 0ece1b35..6da6d72c 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -5,7 +5,8 @@ using StableRNGs using LinearAlgebra: Diagonal using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) @testset "eig_full! for T = $T" for T in BLASFloats rng = StableRNG(123) @@ -91,7 +92,7 @@ end @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end -@testset "eig for Diagonal{$T}" for T in BLASFloats +@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) m = 54 Ad = randn(rng, T, m) diff --git a/test/eigh.jl b/test/eigh.jl index bc04b057..92b0f3a0 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -5,7 +5,8 @@ using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) @testset "eigh_full! for T = $T" for T in BLASFloats rng = StableRNG(123) @@ -100,7 +101,7 @@ end @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end -@testset "eigh for Diagonal{$T}" for T in BLASFloats +@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) m = 54 Ad = randn(rng, T, m) diff --git a/test/genericlinearalgebra/eigh.jl b/test/genericlinearalgebra/eigh.jl new file mode 100644 index 00000000..7e602026 --- /dev/null +++ b/test/genericlinearalgebra/eigh.jl @@ -0,0 +1,93 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I +using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using GenericLinearAlgebra + +const eltypes = (BigFloat, Complex{BigFloat}) + +@testset "eigh_full! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + alg = GLA_QRIteration() + + A = randn(rng, T, m, m) + A = (A + A') / 2 + + D, V = @constinferred eigh_full(A; alg) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(copy(A), (D, V), alg) + @test D2 ≈ D + @test V2 ≈ V + + D3 = @constinferred eigh_vals(A, alg) + @test D ≈ Diagonal(D3) +end + +@testset "eigh_trunc! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + alg = GLA_QRIteration() + A = randn(rng, T, m, m) + A = A * A' + A = (A + A') / 2 + Ac = similar(A) + D₀ = reverse(eigh_vals(A)) + + r = m - 2 + s = 1 + sqrt(eps(real(T))) + atol = sqrt(eps(real(T))) + + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) + Dfull, Vfull = eigh_full(A; alg) + @test length(diagview(D1)) == r + @test isisometric(V1) + @test A * V1 ≈ V1 * D1 + @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + trunc = trunctol(; atol = s * D₀[r + 1]) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) + @test length(diagview(D2)) == r + @test isisometric(V2) + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 - sqrt(eps(real(T))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + @test V1 * (V1' * V3) ≈ V3 + @test V3 * (V3' * V1) ≈ V1 +end + +@testset "eigh_trunc! specify truncation algorithm T = $T" for T in eltypes + rng = StableRNG(123) + m = 4 + atol = sqrt(eps(real(T))) + V = qr_compact(randn(rng, T, m, m))[1] + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(GLA_QRIteration(), truncrank(2)) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + + alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2)) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol +end diff --git a/test/genericlinearalgebra/lq.jl b/test/genericlinearalgebra/lq.jl new file mode 100644 index 00000000..dc186dcb --- /dev/null +++ b/test/genericlinearalgebra/lq.jl @@ -0,0 +1,124 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I, Diagonal +using GenericLinearAlgebra + +eltypes = (BigFloat, Complex{BigFloat}) + +@testset "qr_compact! for T = $T" for T in eltypes + + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + m = 54 + A = randn(rng, T, m, n) + L, Q = @constinferred lq_compact(A) + @test L isa Matrix{T} && size(L) == (m, minmn) + @test Q isa Matrix{T} && size(Q) == (minmn, n) + @test L * Q ≈ A + @test isisometric(Q; side = :right) + + Ac = similar(A) + L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q)) + @test L2 === L + @test Q2 === Q + + noL = similar(A, 0, minmn) + Q2 = similar(Q) + lq_compact!(copy!(Ac, A), (noL, Q2)) + @test Q == Q2 + + # Transposed QR algorithm + qr_alg = GLA_HouseholderQR() + lq_alg = LQViaTransposedQR(qr_alg) + L2, Q2 = @constinferred lq_compact!(copy!(Ac, A), (L, Q), lq_alg) + @test L2 === L + @test Q2 === Q + noL = similar(A, 0, minmn) + Q2 = similar(Q) + lq_compact!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q == Q2 + + @test_throws ArgumentError lq_compact(A; blocksize = 2) + @test_throws ArgumentError lq_compact(A; pivoted = true) + + # positive + lq_compact!(copy!(Ac, A), (L, Q); positive = true) + @test L * Q ≈ A + @test isisometric(Q; side = :right) + @test all(>=(zero(real(T))), real(diag(L))) + lq_compact!(copy!(Ac, A), (noL, Q2); positive = true) + @test Q == Q2 + end +end + +@testset "lq_full! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = randn(rng, T, m, n) + L, Q = lq_full(A) + @test L isa Matrix{T} && size(L) == (m, n) + @test Q isa Matrix{T} && size(Q) == (n, n) + @test L * Q ≈ A + @test isunitary(Q) + + Ac = similar(A) + L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q)) + @test L2 === L + @test Q2 === Q + @test L * Q ≈ A + @test isunitary(Q) + + noL = similar(A, 0, n) + Q2 = similar(Q) + lq_full!(copy!(Ac, A), (noL, Q2)) + @test Q[1:minmn, n] ≈ Q2[1:minmn, n] + + # Transposed QR algorithm + qr_alg = GLA_HouseholderQR() + lq_alg = LQViaTransposedQR(qr_alg) + L2, Q2 = @constinferred lq_full!(copy!(Ac, A), (L, Q), lq_alg) + @test L2 === L + @test Q2 === Q + @test L * Q ≈ A + @test Q * Q' ≈ I + noL = similar(A, 0, n) + Q2 = similar(Q) + lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q[1:minmn, n] ≈ Q2[1:minmn, n] + + # Argument errors for unsupported options + @test_throws ArgumentError lq_full(A; blocksize = 2) + @test_throws ArgumentError lq_full(A; pivoted = true) + + # positive + lq_full!(copy!(Ac, A), (L, Q); positive = true) + @test L * Q ≈ A + @test isunitary(Q) + @test all(>=(zero(real(T))), real(diag(L))) + lq_full!(copy!(Ac, A), (noL, Q2); positive = true) + @test Q[1:minmn, n] ≈ Q2[1:minmn, n] + + qr_alg = GLA_HouseholderQR(; positive = true) + lq_alg = LQViaTransposedQR(qr_alg) + lq_full!(copy!(Ac, A), (L, Q), lq_alg) + @test L * Q ≈ A + @test Q * Q' ≈ I + @test all(>=(zero(real(T))), real(diag(L))) + lq_full!(copy!(Ac, A), (noL, Q2), lq_alg) + @test Q[1:minmn, n] ≈ Q2[1:minmn, n] + + # positive and blocksize 1 + lq_full!(copy!(Ac, A), (L, Q); positive = true, blocksize = 1) + @test L * Q ≈ A + @test isunitary(Q) + @test all(>=(zero(real(T))), real(diag(L))) + lq_full!(copy!(Ac, A), (noL, Q2); positive = true, blocksize = 1) + @test Q[1:minmn, n] ≈ Q2[1:minmn, n] + end +end diff --git a/test/genericlinearalgebra/qr.jl b/test/genericlinearalgebra/qr.jl new file mode 100644 index 00000000..3ce530bb --- /dev/null +++ b/test/genericlinearalgebra/qr.jl @@ -0,0 +1,109 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I, Diagonal +using GenericLinearAlgebra + +eltypes = (BigFloat, Complex{BigFloat}) + +@testset "qr_compact! for T = $T" for T in eltypes + + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + m = 54 + A = randn(rng, T, m, n) + Q, R = @constinferred qr_compact(A) + @test Q isa Matrix{T} && size(Q) == (m, minmn) + @test R isa Matrix{T} && size(R) == (minmn, n) + @test Q * R ≈ A + + Ac = similar(A) + Q2, R2 = @constinferred qr_compact!(copy!(Ac, A), (Q, R)) + @test Q2 === Q + @test R2 === R + + Q2 = similar(Q) + noR = similar(A, minmn, 0) + qr_compact!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + @test_throws ArgumentError qr_compact(A; blocksize = 2) + @test_throws ArgumentError qr_compact(A; pivoted = true) + + # positive + qr_compact!(copy!(Ac, A), (Q, R); positive = true) + @test Q * R ≈ A + @test isisometric(Q) + @test all(>=(zero(real(T))), real(diag(R))) + qr_compact!(copy!(Ac, A), (Q2, noR); positive = true) + @test Q == Q2 + end +end + +@testset "qr_full! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + for n in (37, m, 63) + minmn = min(m, n) + A = randn(rng, T, m, n) + Q, R = qr_full(A) + @test Q isa Matrix{T} && size(Q) == (m, m) + @test R isa Matrix{T} && size(R) == (m, n) + Qc, Rc = qr_compact(A) + @test Q * R ≈ A + @test isunitary(Q) + + Ac = similar(A) + Q2 = similar(Q) + noR = similar(A, m, 0) + Q2, R2 = @constinferred qr_full!(copy!(Ac, A), (Q, R)) + @test Q2 === Q + @test R2 === R + @test Q * R ≈ A + @test isunitary(Q) + qr_full!(copy!(Ac, A), (Q2, noR)) + @test Q == Q2 + + # unblocked algorithm + qr_full!(copy!(Ac, A), (Q, R); blocksize = 1) + @test Q * R ≈ A + @test isunitary(Q) + qr_full!(copy!(Ac, A), (Q2, noR); blocksize = 1) + @test Q == Q2 + if n == m + qr_full!(copy!(Q2, A), (Q2, noR); blocksize = 1) # in-place Q + @test Q ≈ Q2 + end + + # Argument errors for unsupported options + @test_throws ArgumentError qr_full(A; blocksize = 2) + @test_throws ArgumentError qr_compact(A; pivoted = true) + + # positive + qr_full!(copy!(Ac, A), (Q, R); positive = true) + @test Q * R ≈ A + @test isunitary(Q) + @test all(>=(zero(real(T))), real(diag(R))) + qr_full!(copy!(Ac, A), (Q2, noR); positive = true) + @test Q == Q2 + # positive and blocksize 1 + qr_full!(copy!(Ac, A), (Q, R); positive = true, blocksize = 1) + @test Q * R ≈ A + @test isunitary(Q) + @test all(>=(zero(real(T))), real(diag(R))) + qr_full!(copy!(Ac, A), (Q2, noR); positive = true, blocksize = 1) + @test Q == Q2 + if n <= m + # the following test tries to find the diagonal element (in order to test positivity) + # before the column permutation. This only works if all columns have a diagonal + # element + for j in 1:n + i = findlast(!iszero, view(R, :, j)) + @test real(R[i, j]) >= zero(real(T)) + end + end + end +end diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl new file mode 100644 index 00000000..f7177e79 --- /dev/null +++ b/test/genericlinearalgebra/svd.jl @@ -0,0 +1,171 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm +using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric +using GenericLinearAlgebra + +eltypes = (BigFloat, Complex{BigFloat}) + +@testset "svd_compact! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63, 0) + k = min(m, n) + alg = GLA_QRIteration() + minmn = min(m, n) + A = randn(rng, T, m, n) + + if VERSION < v"1.11" + # This is type unstable on older versions of Julia. + U, S, Vᴴ = svd_compact(A; alg) + else + U, S, Vᴴ = @constinferred svd_compact(A; alg = ($alg)) + end + @test U isa Matrix{T} && size(U) == (m, minmn) + @test S isa Diagonal{real(T)} && size(S) == (minmn, minmn) + @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) + @test U * S * Vᴴ ≈ A + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) + @test isposdef(S) + + Ac = similar(A) + Sc = similar(A, real(T), min(m, n)) + alg′ = @constinferred MatrixAlgebraKit.select_algorithm(svd_compact!, A, $alg) + U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg′) + @test U2 ≈ U + @test S2 ≈ S + @test V2ᴴ ≈ Vᴴ + @test U * S * Vᴴ ≈ A + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) + @test isposdef(S) + + Sd = @constinferred svd_vals(A, alg′) + @test S ≈ Diagonal(Sd) + end +end + +@testset "svd_full! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + @testset "size ($m, $n)" for n in (37, m, 63, 0) + alg = GLA_QRIteration() + A = randn(rng, T, m, n) + U, S, Vᴴ = svd_full(A; alg) + @test U isa Matrix{T} && size(U) == (m, m) + @test S isa Matrix{real(T)} && size(S) == (m, n) + @test Vᴴ isa Matrix{T} && size(Vᴴ) == (n, n) + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + + Ac = similar(A) + U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U2 ≈ U + @test S2 ≈ S + @test V2ᴴ ≈ Vᴴ + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + + Sc = svd_vals!(copy!(Ac, A), alg) + @test diagview(S) ≈ Sc + end + @testset "size (0, 0)" begin + @testset "algorithm $alg" for alg in + (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) + A = randn(rng, T, 0, 0) + U, S, Vᴴ = svd_full(A; alg) + @test U isa Matrix{T} && size(U) == (0, 0) + @test S isa Matrix{real(T)} && size(S) == (0, 0) + @test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0) + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + end + end +end + +@testset "svd_trunc! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 54 + atol = sqrt(eps(real(T))) + alg = GLA_QRIteration() + + @testset "size ($m, $n)" for n in (37, m, 63) + n > m && alg isa LAPACK_Jacobi && continue # not supported + A = randn(rng, T, m, n) + S₀ = svd_vals(A) + minmn = min(m, n) + r = minmn - 2 + + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + @test length(diagview(S1)) == r + @test diagview(S1) ≈ S₀[1:r] + @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + s = 1 + sqrt(eps(real(T))) + trunc = trunctol(; atol = s * S₀[r + 1]) + + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) + @test length(diagview(S2)) == r + @test U1 ≈ U2 + @test S1 ≈ S2 + @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) + @test length(diagview(S3)) == r + @test U1 ≈ U3 + @test S1 ≈ S3 + @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + end +end + +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in eltypes + rng = StableRNG(123) + alg = GLA_QRIteration() + m = 4 + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + + for trunc_fun in ( + (rtol, maxrank) -> (; rtol, maxrank), + (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), + ) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + @test length(diagview(S1)) == 1 + @test diagview(S1) ≈ diagview(S)[1:1] + + U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] + end +end + +@testset "svd_trunc! specify truncation algorithm T = $T" for T in eltypes + rng = StableRNG(123) + atol = sqrt(eps(real(T))) + m = 4 + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + alg = TruncatedAlgorithm(GLA_QRIteration(), trunctol(; atol = 0.2)) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] + @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) +end diff --git a/test/genericschur/eig.jl b/test/genericschur/eig.jl new file mode 100644 index 00000000..ce1e8f1b --- /dev/null +++ b/test/genericschur/eig.jl @@ -0,0 +1,116 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: Diagonal +using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using GenericSchur + +const eltypes = (BigFloat, Complex{BigFloat}) + +@testset "eig_full! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 24 + alg = GS_QRIteration() + A = randn(rng, T, m, m) + Tc = complex(T) + + D, V = @constinferred eig_full(A; alg = ($alg)) + @test eltype(D) == eltype(V) == Tc + @test A * V ≈ V * D + + alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) + + Ac = similar(A) + D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) + @test D2 ≈ D + @test V2 ≈ V + @test A * V ≈ V * D + + Dc = @constinferred eig_vals(A, alg′) + @test eltype(Dc) == Tc + @test D ≈ Diagonal(Dc) +end + +@testset "eig_trunc! for T = $T" for T in eltypes + rng = StableRNG(123) + m = 6 + alg = GS_QRIteration() + A = randn(rng, T, m, m) + A *= A' # TODO: deal with eigenvalue ordering etc + # eigenvalues are sorted by ascending real component... + D₀ = sort!(eig_vals(A); by = abs, rev = true) + rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) + r = length(D₀) - rmin + atol = sqrt(eps(real(T))) + + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) + D1base, V1base = @constinferred eig_full(A; alg) + + @test length(diagview(D1)) == r + @test A * V1 ≈ V1 * D1 + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 + sqrt(eps(real(T))) + trunc = trunctol(; atol = s * abs(D₀[r + 1])) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) + @test length(diagview(D2)) == r + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 - sqrt(eps(real(T))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # trunctol keeps order, truncrank might not + # test for same subspace + @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 + @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 + @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3 + @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1 +end + +@testset "eig_trunc! specify truncation algorithm T = $T" for T in eltypes + rng = StableRNG(123) + m = 4 + atol = sqrt(eps(real(T))) + V = randn(rng, T, m, m) + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) + A = V * D * inv(V) + alg = TruncatedAlgorithm(GS_QRIteration(), truncrank(2)) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) + + alg = TruncatedAlgorithm(GS_QRIteration(), truncerror(; atol = 0.2, p = 1)) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol +end + +@testset "eig for Diagonal{$T}" for T in eltypes + rng = StableRNG(123) + m = 24 + Ad = randn(rng, T, m) + A = Diagonal(Ad) + atol = sqrt(eps(real(T))) + + D, V = @constinferred eig_full(A) + @test D isa Diagonal{T} && size(D) == size(A) + @test V isa Diagonal{T} && size(V) == size(A) + @test A * V ≈ V * D + + D2 = @constinferred eig_vals(A) + @test D2 isa AbstractVector{T} && length(D2) == m + @test diagview(D) ≈ D2 + + A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) + D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) + @test diagview(D2) ≈ diagview(A2)[1:2] + @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol +end diff --git a/test/lq.jl b/test/lq.jl index 2c8dfefe..8de5a582 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -5,9 +5,10 @@ using StableRNGs using LinearAlgebra: diag, I, Diagonal using MatrixAlgebraKit: LQViaTransposedQR, LAPACK_HouseholderQR -eltypes = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "lq_compact! for T = $T" for T in eltypes +@testset "lq_compact! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for n in (37, m, 63) @@ -114,7 +115,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) end end -@testset "lq_full! for T = $T" for T in eltypes +@testset "lq_full! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for n in (37, m, 63) @@ -208,7 +209,7 @@ end end end -@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in eltypes +@testset "lq_compact, lq_full and lq_null for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) atol = eps(real(T))^(3 / 4) for m in (54, 0) diff --git a/test/qr.jl b/test/qr.jl index 826c320b..c4f0c9d6 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -4,9 +4,10 @@ using TestExtras using StableRNGs using LinearAlgebra: diag, I, Diagonal -eltypes = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) -@testset "qr_compact! and qr_null! for T = $T" for T in eltypes +@testset "qr_compact! and qr_null! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for n in (37, m, 63) @@ -99,7 +100,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) end end -@testset "qr_full! for T = $T" for T in eltypes +@testset "qr_full! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 for n in (37, m, 63) @@ -176,7 +177,7 @@ end end end -@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in eltypes +@testset "qr_compact, qr_full and qr_null for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) atol = eps(real(T))^(3 / 4) for m in (54, 0) diff --git a/test/runtests.jl b/test/runtests.jl index af4996ed..ec255538 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -107,3 +107,20 @@ if AMDGPU.functional() include("amd/orthnull.jl") end end + +using GenericLinearAlgebra +@safetestset "QR / LQ Decomposition" begin + include("genericlinearalgebra/qr.jl") + include("genericlinearalgebra/lq.jl") +end +@safetestset "Singular Value Decomposition" begin + include("genericlinearalgebra/svd.jl") +end +@safetestset "Hermitian Eigenvalue Decomposition" begin + include("genericlinearalgebra/eigh.jl") +end + +using GenericSchur +@safetestset "General Eigenvalue Decomposition" begin + include("genericschur/eig.jl") +end diff --git a/test/svd.jl b/test/svd.jl index acb27946..d055f866 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -5,7 +5,8 @@ using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric -const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) @testset "svd_compact! for T = $T" for T in BLASFloats rng = StableRNG(123) @@ -202,7 +203,7 @@ end @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end -@testset "svd for Diagonal{$T}" for T in BLASFloats +@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) atol = sqrt(eps(real(T))) for m in (54, 0)