diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 4263c592..9e7e6600 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -152,6 +152,14 @@ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ + m, n = size(A) + minmn = min(m, n) + if minmn == 0 + one!(U) + zero!(S) + one!(Vᴴ) + return USVᴴ + end do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) @@ -181,6 +189,12 @@ end function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) check_input(svd_vals!, A, S, alg) + m, n = size(A) + minmn = min(m, n) + if minmn == 0 + zero!(S) + return S + end U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) @@ -256,6 +270,12 @@ function svd_compact!(A, USVᴴ, alg::DiagonalAlgorithm) end function svd_vals!(A::AbstractMatrix, S, alg::DiagonalAlgorithm) check_input(svd_vals!, A, S, alg) + m, n = size(A) + minmn = min(m, n) + if minmn == 0 + zero!(S) + return S + end Ad = diagview(A) S .= abs.(Ad) sort!(S; rev = true) @@ -407,6 +427,14 @@ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ + m, n = size(A) + minmn = min(m, n) + if minmn == 0 + one!(U) + zero!(S) + one!(Vᴴ) + return USVᴴ + end do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) @@ -431,6 +459,12 @@ _largest(x, y) = abs(x) < abs(y) ? y : x function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) check_input(svd_vals!, A, S, alg) + m, n = size(A) + minmn = min(m, n) + if minmn == 0 + zero!(S) + return S + end U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) diff --git a/test/amd/svd.jl b/test/amd/svd.jl deleted file mode 100644 index fcd5b490..00000000 --- a/test/amd/svd.jl +++ /dev/null @@ -1,157 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef -using Test -using TestExtras -using StableRNGs -using AMDGPU - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "svd_compact! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63) - k = min(m, n) - algs(::ROCArray) = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - As = m == n ? (ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m)))) : (ROCArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - minmn = min(m, n) - - U, S, Vᴴ = svd_compact(A; alg) - @test U isa ROCMatrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T), <:ROCVector} && size(S) == (minmn, minmn) - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Sd = svd_vals(A, alg) - @test ROCArray(diagview(S)) ≈ Sd - # ROCArray is necessary because norm of ROCArray view with non-unit step is broken - if alg isa ROCSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) - end - end - end - end -end - -@testset "svd_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - algs(::ROCArray) = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - @testset "size ($m, $n)" for n in (37, m, 63) - As = m == n ? (ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m)))) : (ROCArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - U, S, Vᴴ = svd_full(A; alg) - @test U isa ROCMatrix{T} && size(U) == (m, m) - if A isa Diagonal - @test S isa Diagonal{real(T), <:ROCVector{real(T)}} && size(S) == (m, n) - else - @test S isa ROCMatrix{real(T)} && size(S) == (m, n) - end - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - Sc = similar(A, real(T), min(m, n)) - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) - @test Sc === Sc2 - @test ROCArray(diagview(S)) ≈ Sc - # ROCArray is necessary because norm of ROCArray view with non-unit step is broken - if alg isa ROCSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad")) - end - end - end - end - @testset "size (0, 0)" begin - for A in (ROCArray(randn(rng, T, 0, 0)), Diagonal(ROCArray(randn(rng, T, 0)))) - @testset "algorithm $alg" for alg in algs(A) - U, S, Vᴴ = svd_full(A; alg) - @test U isa ROCMatrix{T} && size(U) == (0, 0) - if isa(A, Diagonal) - @test S isa Diagonal{real(T), <:ROCVector{real(T)}} - else - @test S isa ROCMatrix{real(T)} - end - @test Vᴴ isa ROCMatrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - end - end - end -end - -# @testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) -# rng = StableRNG(123) -# m = 54 -# if LinearAlgebra.LAPACK.version() < v"3.12.0" -# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) -# else -# algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), -# LAPACK_Jacobi()) -# end -# -# @testset "size ($m, $n)" for n in (37, m, 63) -# @testset "algorithm $alg" for alg in algs -# n > m && alg isa LAPACK_Jacobi && continue # not supported -# A = randn(rng, T, m, n) -# S₀ = svd_vals(A) -# minmn = min(m, n) -# r = minmn - 2 -# -# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) -# @test length(S1.diag) == r -# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] -# -# s = 1 + sqrt(eps(real(T))) -# trunc2 = trunctol(; atol=s * S₀[r + 1]) -# -# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) -# @test length(S2.diag) == r -# @test U1 ≈ U2 -# @test S1 ≈ S2 -# @test V1ᴴ ≈ V2ᴴ -# end -# end -# end diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl deleted file mode 100644 index 8d931b3b..00000000 --- a/test/cuda/svd.jl +++ /dev/null @@ -1,170 +0,0 @@ -using MatrixAlgebraKit -using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef, norm, opnorm -using Test -using TestExtras -using StableRNGs -using CUDA - -include(joinpath("..", "utilities.jl")) - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "svd_compact! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63) - k = min(m, n) - algs(::CuArray) = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - As = m == n ? (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) : (CuArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - minmn = min(m, n) - U, S, Vᴴ = svd_compact(A; alg) - @test U isa CuMatrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T), <:CuVector} && size(S) == (minmn, minmn) - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(Vᴴ * Vᴴ') - @test isposdef(S) - - Sd = svd_vals(A, alg) - @test CuArray(diagview(S)) ≈ Sd - # CuArray is necessary because norm of CuArray view with non-unit step is broken - if alg isa CUSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) - end - end - end - end -end - -@testset "svd_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - algs(::CuArray) = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi()) - algs(::Diagonal) = (DiagonalAlgorithm(),) - @testset "size ($m, $n)" for n in (37, m, 63) - As = m == n ? (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m)))) : (CuArray(randn(rng, T, m, n)),) - for A in As - @testset "algorithm $alg" for alg in algs(A) - minmn = min(m, n) - U, S, Vᴴ = svd_full(A; alg) - @test U isa CuMatrix{T} && size(U) == (m, m) - if A isa Diagonal - @test S isa Diagonal{real(T), <:CuVector{real(T)}} && size(S) == (m, n) - else - @test S isa CuMatrix{real(T)} && size(S) == (m, n) - end - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - - minmn = min(m, n) - Sc = similar(A, real(T), minmn) - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) - @test Sc === Sc2 - @test CuArray(diagview(S)) ≈ Sc - # CuArray is necessary because norm of CuArray view with non-unit step is broken - if alg isa CUSOLVER_QRIteration - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) - @test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad")) - end - end - end - end - @testset "size (0, 0)" begin - for A in (CuArray(randn(rng, T, 0, 0)), Diagonal(CuArray(randn(rng, T, 0)))) - @testset "algorithm $alg" for alg in algs(A) - U, S, Vᴴ = svd_full(A; alg) - @test U isa CuMatrix{T} && size(U) == (0, 0) - @test size(S) == (0, 0) - if isa(A, Diagonal) - @test S isa Diagonal{real(T), <:CuVector{real(T)}} - else - @test S isa CuMatrix{real(T)} - end - @test Vᴴ isa CuMatrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isapproxone(U' * U) - @test isapproxone(U * U') - @test isapproxone(Vᴴ * Vᴴ') - @test isapproxone(Vᴴ' * Vᴴ) - @test all(isposdef, diagview(S)) - end - end - end -end - -@testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63) - k = min(m, n) - 20 - p = min(m, n) - k - 1 - algs(::CuArray) = (CUSOLVER_QRIteration(), CUSOLVER_SVDPolar(), CUSOLVER_Jacobi(), CUSOLVER_Randomized(; k = k, p = p, niters = 100)) - algs(::Diagonal) = (DiagonalAlgorithm(),) - hAs = m == n ? (randn(rng, T, m, n), Diagonal(randn(rng, T, m))) : (randn(rng, T, m, n),) - minmn = min(m, n) - for hA in hAs - A = CuArray(hA) - @testset "algorithm $alg" for alg in algs(A) - S₀ = svd_vals(hA) - r = k - - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(S1.diag) == r - @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 - U1, S1, V1ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = truncrank(r)) - @test length(S1.diag) == r - @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - - if !(alg isa CUSOLVER_Randomized) - s = 1 + sqrt(eps(real(T))) - trunc2 = trunctol(; atol = s * S₀[r + 1]) - - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) - @test length(S2.diag) == r - @test U1 ≈ U2 - @test parent(S1) ≈ parent(S2) - @test V1ᴴ ≈ V2ᴴ - - U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) - @test length(S2.diag) == r - @test U1 ≈ U2 - @test parent(S1) ≈ parent(S2) - @test V1ᴴ ≈ V2ᴴ - end - end - end - end -end diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl deleted file mode 100644 index 9cfdabf7..00000000 --- a/test/genericlinearalgebra/svd.jl +++ /dev/null @@ -1,172 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric -using GenericLinearAlgebra - -eltypes = (BigFloat, Complex{BigFloat}) - -@testset "svd_compact! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - k = min(m, n) - alg = GLA_QRIteration() - minmn = min(m, n) - A = randn(rng, T, m, n) - - if VERSION < v"1.11" - # This is type unstable on older versions of Julia. - U, S, Vᴴ = svd_compact(A; alg) - else - U, S, Vᴴ = @constinferred svd_compact(A; alg = ($alg)) - end - @test U isa Matrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T)} && size(S) == (minmn, minmn) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) - - Ac = similar(A) - Sc = similar(A, real(T), min(m, n)) - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(svd_compact!, A, $alg) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg′) - @test U2 ≈ U - @test S2 ≈ S - @test V2ᴴ ≈ Vᴴ - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) - - Sd = @constinferred svd_vals(A, alg′) - @test S ≈ Diagonal(Sd) - end -end - -@testset "svd_full! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - alg = GLA_QRIteration() - A = randn(rng, T, m, n) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (m, m) - @test S isa Matrix{real(T)} && size(S) == (m, n) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 ≈ U - @test S2 ≈ S - @test V2ᴴ ≈ Vᴴ - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Sc = svd_vals!(copy!(Ac, A), alg) - @test diagview(S) ≈ Sc - end - @testset "size (0, 0)" begin - @testset "algorithm $alg" for alg in - (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) - A = randn(rng, T, 0, 0) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (0, 0) - @test S isa Matrix{real(T)} && size(S) == (0, 0) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - end - end -end - -@testset "svd_trunc! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - atol = sqrt(eps(real(T))) - alg = GLA_QRIteration() - - @testset "size ($m, $n)" for n in (37, m, 63) - n > m && alg isa LAPACK_Jacobi && continue # not supported - A = randn(rng, T, m, n) - S₀ = svd_vals(A) - minmn = min(m, n) - r = minmn - 2 - - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(S1)) == r - @test diagview(S1) ≈ S₀[1:r] - @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * S₀[r + 1]) - - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S2)) == r - @test U1 ≈ U2 - @test S1 ≈ S2 - @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S3)) == r - @test U1 ≈ U3 - @test S1 ≈ S3 - @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - end -end - -@testset "svd_trunc! mix maxrank and tol for T = $T" for T in eltypes - rng = StableRNG(123) - alg = GLA_QRIteration() - m = 4 - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - - for trunc_fun in ( - (rtol, maxrank) -> (; rtol, maxrank), - (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), - ) - U1, S1, V1ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 1)) - @test length(diagview(S1)) == 1 - @test diagview(S1) ≈ diagview(S)[1:1] - - U2, S2, V2ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 3)) - @test length(diagview(S2)) == 2 - @test diagview(S2) ≈ diagview(S)[1:2] - end -end - -@testset "svd_trunc! specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - atol = sqrt(eps(real(T))) - m = 4 - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - alg = TruncatedAlgorithm(GLA_QRIteration(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] - @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol - @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) - @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) -end diff --git a/test/lq.jl b/test/lq.jl index 4da1fc7a..1a7e7c90 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -51,9 +51,9 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) ) TestSuite.test_lq_algs(T, (m, n), LAPACK_LQ_ALGS) elseif T ∈ GenericFloats - TestSuite.test_lq(T, (m, n); test_null = true, test_pivoted = false, test_blocksize = false) + TestSuite.test_lq(T, (m, n); test_pivoted = false, test_blocksize = false) GLA_LQ_ALGS = (LQViaTransposedQR(GLA_HouseholderQR()),) - TestSuite.test_lq_algs(T, (m, n), GLA_LQ_ALGS; test_null = true) + TestSuite.test_lq_algs(T, (m, n), GLA_LQ_ALGS) end if m == n AT = Diagonal{T, Vector{T}} diff --git a/test/qr.jl b/test/qr.jl index 3131349a..a0cd6d65 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -50,9 +50,9 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) ) TestSuite.test_qr_algs(T, (m, n), LAPACK_QR_ALGS) elseif T ∈ GenericFloats - TestSuite.test_qr(T, (m, n); test_null = true, test_pivoted = false, test_blocksize = false) + TestSuite.test_qr(T, (m, n); test_pivoted = false, test_blocksize = false) GLA_QR_ALGS = (GLA_HouseholderQR(),) - TestSuite.test_qr_algs(T, (m, n), GLA_QR_ALGS; test_null = false) + TestSuite.test_qr_algs(T, (m, n), GLA_QR_ALGS) end if m == n AT = Diagonal{T, Vector{T}} diff --git a/test/runtests.jl b/test/runtests.jl index 9fe64002..8232b2ae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,9 +10,6 @@ if !is_buildkite @safetestset "Truncate" begin include("truncate.jl") end - @safetestset "Singular Value Decomposition" begin - include("svd.jl") - end @safetestset "Generalized Eigenvalue Decomposition" begin include("gen_eig.jl") end @@ -34,11 +31,6 @@ if !is_buildkite JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end end - - using GenericLinearAlgebra - @safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") - end end @safetestset "QR / LQ Decomposition" begin @@ -63,17 +55,6 @@ end @safetestset "Image and Null Space" begin include("orthnull.jl") end - -using CUDA -if CUDA.functional() - @safetestset "CUDA SVD" begin - include("cuda/svd.jl") - end -end - -using AMDGPU -if AMDGPU.functional() - @safetestset "AMDGPU SVD" begin - include("amd/svd.jl") - end +@safetestset "Singular Value Decomposition" begin + include("svd.jl") end diff --git a/test/svd.jl b/test/svd.jl index a41e075c..40d0528e 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -2,244 +2,62 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef, norm -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, isisometric +using LinearAlgebra: Diagonal +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -GenericFloats = (Float16, BigFloat, Complex{BigFloat}) - -@testset "svd_compact! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - k = min(m, n) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), - LAPACK_DivideAndConquer, :LAPACK_DivideAndConquer, - ) - else - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), - LAPACK_Jacobi(), LAPACK_DivideAndConquer, :LAPACK_DivideAndConquer, +GenericFloats = (BigFloat, Complex{BigFloat}) + +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63) + TestSuite.seed_rng!(123) + if T ∈ BLASFloats + if CUDA.functional() + TestSuite.test_svd(CuMatrix{T}, (m, n); test_trunc = false) + CUDA_SVD_ALGS = ( + CUSOLVER_QRIteration(), + CUSOLVER_SVDPolar(), + CUSOLVER_Jacobi(), ) - end - @testset "algorithm $alg" for alg in algs - n > m && alg isa LAPACK_Jacobi && continue # not supported - minmn = min(m, n) - A = randn(rng, T, m, n) - - if VERSION < v"1.11" - # This is type unstable on older versions of Julia. - U, S, Vᴴ = svd_compact(A; alg) - else - U, S, Vᴴ = @constinferred svd_compact(A; alg = ($alg)) + TestSuite.test_svd_algs(CuMatrix{T}, (m, n), CUDA_SVD_ALGS; test_trunc = false) + if n == m + TestSuite.test_svd(Diagonal{T, CuVector{T}}, m; test_trunc = false) + TestSuite.test_svd_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) end - @test U isa Matrix{T} && size(U) == (m, minmn) - @test S isa Diagonal{real(T)} && size(S) == (minmn, minmn) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (minmn, n) - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) - - Ac = similar(A) - Sc = similar(A, real(T), min(m, n)) - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(svd_compact!, A, $alg) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg′) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isisometric(U) - @test isisometric(Vᴴ; side = :right) - @test isposdef(S) - - Sd = @constinferred svd_vals(A, alg′) - @test S ≈ Diagonal(Sd) - end - end -end - -@testset "svd_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - @testset "size ($m, $n)" for n in (37, m, 63, 0) - @testset "algorithm $alg" for alg in - (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) - A = randn(rng, T, m, n) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (m, m) - @test S isa Matrix{real(T)} && size(S) == (m, n) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (n, n) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Ac = similar(A) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) - @test U2 === U - @test S2 === S - @test V2ᴴ === Vᴴ - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - - Sc = similar(A, real(T), min(m, n)) - Sc2 = svd_vals!(copy!(Ac, A), Sc, alg) - @test Sc === Sc2 - @test diagview(S) ≈ Sc end - end - @testset "size (0, 0)" begin - @testset "algorithm $alg" for alg in - (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) - A = randn(rng, T, 0, 0) - U, S, Vᴴ = svd_full(A; alg) - @test U isa Matrix{T} && size(U) == (0, 0) - @test S isa Matrix{real(T)} && size(S) == (0, 0) - @test Vᴴ isa Matrix{T} && size(Vᴴ) == (0, 0) - @test U * S * Vᴴ ≈ A - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(isposdef, diagview(S)) - end - end -end - -@testset "svd_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - atol = sqrt(eps(real(T))) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) - else - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi(), - ) - end - - @testset "size ($m, $n)" for n in (37, m, 63) - @testset "algorithm $alg" for alg in algs - n > m && alg isa LAPACK_Jacobi && continue # not supported - A = randn(rng, T, m, n) - S₀ = svd_vals(A) - minmn = min(m, n) - r = minmn - 2 - - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(S1)) == r - @test diagview(S1) ≈ S₀[1:r] - @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * S₀[r + 1]) - - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S2)) == r - @test U1 ≈ U2 - @test S1 ≈ S2 - @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol - - trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) - @test length(diagview(S3)) == r - @test U1 ≈ U3 - @test S1 ≈ S3 - @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + if AMDGPU.functional() + TestSuite.test_svd(ROCMatrix{T}, (m, n); test_trunc = false) + AMD_SVD_ALGS = ( + ROCSOLVER_QRIteration(), + ROCSOLVER_Jacobi(), + ) + TestSuite.test_svd_algs(ROCMatrix{T}, (m, n), AMD_SVD_ALGS; test_trunc = false) + if n == m + TestSuite.test_svd(Diagonal{T, ROCVector{T}}, m; test_trunc = false) + TestSuite.test_svd_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) + end end end -end - -@testset "svd_trunc! mix maxrank and tol for T = $T" for T in BLASFloats - rng = StableRNG(123) - if LinearAlgebra.LAPACK.version() < v"3.12.0" - algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) - else - algs = ( - LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_Jacobi(), - ) - end - m = 4 - @testset "algorithm $alg" for alg in algs - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - - for trunc_fun in ( - (rtol, maxrank) -> (; rtol, maxrank), - (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), + if !is_buildkite + if T ∈ BLASFloats + LAPACK_SVD_ALGS = ( + LAPACK_QRIteration(), + LAPACK_DivideAndConquer(), ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) - @test length(diagview(S1)) == 1 - @test diagview(S1) ≈ diagview(S)[1:1] - - U2, S2, V2ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 3)) - @test length(diagview(S2)) == 2 - @test diagview(S2) ≈ diagview(S)[1:2] + TestSuite.test_svd(T, (m, n)) + TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) + elseif T ∈ GenericFloats + TestSuite.test_svd(T, (m, n)) + TestSuite.test_svd_algs(T, (m, n), (GLA_QRIteration(),)) + end + if m == n + AT = Diagonal{T, Vector{T}} + TestSuite.test_svd(AT, m) + TestSuite.test_svd_algs(AT, m, (DiagonalAlgorithm(),)) end - end -end - -@testset "svd_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - atol = sqrt(eps(real(T))) - m = 4 - U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - Vᴴ = qr_compact(randn(rng, T, m, m))[1] - A = U * S * Vᴴ - alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] - @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol - U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] - @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) - @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) -end - -@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - atol = sqrt(eps(real(T))) - for m in (54, 0) - Ad = randn(T, m) - A = Diagonal(Ad) - - U, S, Vᴴ = @constinferred svd_compact(A) - @test U isa AbstractMatrix{T} && size(U) == size(A) - @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A) - @test S isa Diagonal{real(T)} && size(S) == size(A) - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(≥(0), diagview(S)) - @test A ≈ U * S * Vᴴ - - U, S, Vᴴ = @constinferred svd_full(A) - @test U isa AbstractMatrix{T} && size(U) == size(A) - @test Vᴴ isa AbstractMatrix{T} && size(Vᴴ) == size(A) - @test S isa Diagonal{real(T)} && size(S) == size(A) - @test isunitary(U) - @test isunitary(Vᴴ) - @test all(≥(0), diagview(S)) - @test A ≈ U * S * Vᴴ - - S2 = @constinferred svd_vals(A) - @test S2 isa AbstractVector{real(T)} && length(S2) == m - @test S2 ≈ diagview(S) - - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) - @test diagview(S3) ≈ S2[1:min(m, 2)] - @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol - U3, S3, Vᴴ3 = @constinferred svd_trunc_no_error(A; alg) - @test diagview(S3) ≈ S2[1:min(m, 2)] end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 734d48fd..a0763c7f 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -76,6 +76,9 @@ isrightcomplete(Vᴴ, Nᴴ) = Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), collect(N)) isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N)) +instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1] +instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A)))) + include("qr.jl") include("lq.jl") include("polar.jl") @@ -84,5 +87,6 @@ include("schur.jl") include("eig.jl") include("eigh.jl") include("orthnull.jl") +include("svd.jl") end diff --git a/test/testsuite/qr.jl b/test/testsuite/qr.jl index 8bba9e62..d5ebf402 100644 --- a/test/testsuite/qr.jl +++ b/test/testsuite/qr.jl @@ -1,20 +1,20 @@ using TestExtras -function test_qr(T::Type, sz; test_null = true, kwargs...) +function test_qr(T::Type, sz; kwargs...) summary_str = testargs_summary(T, sz) return @testset "qr $summary_str" begin test_qr_compact(T, sz; kwargs...) test_qr_full(T, sz; kwargs...) - test_null && test_qr_null(T, sz; kwargs...) + test_qr_null(T, sz; kwargs...) end end -function test_qr_algs(T::Type, sz, algs; test_null = true, kwargs...) +function test_qr_algs(T::Type, sz, algs; kwargs...) summary_str = testargs_summary(T, sz) return @testset "qr algorithms $summary_str" begin test_qr_compact_algs(T, sz, algs; kwargs...) test_qr_full_algs(T, sz, algs; kwargs...) - test_null && test_qr_null_algs(T, sz, algs; kwargs...) + test_qr_null_algs(T, sz, algs; kwargs...) end end diff --git a/test/testsuite/svd.jl b/test/testsuite/svd.jl new file mode 100644 index 00000000..d1d8ca33 --- /dev/null +++ b/test/testsuite/svd.jl @@ -0,0 +1,312 @@ +using TestExtras +using GenericLinearAlgebra +using LinearAlgebra: opnorm + +function test_svd(T::Type, sz; test_trunc = true, kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "svd $summary_str" begin + test_svd_compact(T, sz; kwargs...) + test_svd_full(T, sz; kwargs...) + test_trunc && test_svd_trunc(T, sz; kwargs...) + end +end + +function test_svd_algs(T::Type, sz, algs; test_trunc = true, kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "svd algorithms $summary_str" begin + test_svd_compact_algs(T, sz, algs; kwargs...) + test_svd_full_algs(T, sz, algs; kwargs...) + test_trunc && test_svd_trunc_algs(T, sz, algs; kwargs...) + end +end + +function test_svd_compact( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_compact! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + U, S, Vᴴ = @testinferred svd_compact(A) + @test size(U) == (m, minmn) + @test S isa Diagonal{real(eltype(T))} && size(S) == (minmn, minmn) + @test size(Vᴴ) == (minmn, n) + @test U * S * Vᴴ ≈ A + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) + @test isposdef(S) + + Sc = similar(A, real(eltype(T)), min(m, n)) + U2, S2, V2ᴴ = @testinferred svd_compact!(Ac, (U, S, Vᴴ)) + @test U2 * S2 * V2ᴴ ≈ A + @test isisometric(U2) + @test isisometric(V2ᴴ; side = :right) + @test isposdef(S2) + + Sd = @testinferred svd_vals(A) + @test S ≈ Diagonal(Sd) + end +end + +function test_svd_compact_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_compact! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + U, S, Vᴴ = @testinferred svd_compact(A; alg) + @test size(U) == (m, minmn) + @test S isa Diagonal{real(eltype(T))} && size(S) == (minmn, minmn) + @test size(Vᴴ) == (minmn, n) + @test U * S * Vᴴ ≈ A + @test isisometric(U) + @test isisometric(Vᴴ; side = :right) + @test isposdef(S) + + U2, S2, V2ᴴ = @testinferred svd_compact!(Ac, (U, S, Vᴴ); alg) + @test U2 * S2 * V2ᴴ ≈ A + @test isisometric(U2) + @test isisometric(V2ᴴ; side = :right) + @test isposdef(S2) + + Sd = @testinferred svd_vals(A; alg) + @test S ≈ Diagonal(Sd) + end +end + +function test_svd_full( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_full! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + + U, S, Vᴴ = @testinferred svd_full(A) + @test size(U) == (m, m) + @test eltype(S) == real(eltype(T)) && size(S) == (m, n) + @test size(Vᴴ) == (n, n) + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + + U2, S2, V2ᴴ = @testinferred svd_full!(Ac, (U, S, Vᴴ)) + @test U2 * S2 * V2ᴴ ≈ A + @test isunitary(U2) + @test isunitary(V2ᴴ) + @test all(isposdef, diagview(S2)) + + Sc = similar(A, real(eltype(T)), min(m, n)) + Sc2 = @testinferred svd_vals!(copy!(Ac, A), Sc) + @test collect(diagview(S)) ≈ collect(Sc2) + end +end + +function test_svd_full_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_full! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + + U, S, Vᴴ = @testinferred svd_full(A; alg) + @test size(U) == (m, m) + @test eltype(S) == real(eltype(T)) && size(S) == (m, n) + @test size(Vᴴ) == (n, n) + @test U * S * Vᴴ ≈ A + @test isunitary(U) + @test isunitary(Vᴴ) + @test all(isposdef, diagview(S)) + + U2, S2, V2ᴴ = @testinferred svd_full!(Ac, (U, S, Vᴴ); alg) + @test U2 * S2 * V2ᴴ ≈ A + @test isunitary(U2) + @test isunitary(V2ᴴ) + @test all(isposdef, diagview(S2)) + + Sc = similar(A, real(eltype(T)), min(m, n)) + Sc2 = @testinferred svd_vals!(copy!(Ac, A), Sc; alg) + @test collect(diagview(S)) ≈ collect(Sc2) + end +end + +function test_svd_trunc( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_trunc! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + S₀ = svd_vals(A) + r = minmn - 2 + + if m > 0 && n > 0 + U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r)) + @test length(diagview(S1)) == r + @test diagview(S1) ≈ S₀[1:r] + @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + s = 1 + sqrt(eps(real(eltype(T)))) + trunc = trunctol(; atol = s * S₀[r + 1]) + + U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; trunc) + @test length(diagview(S2)) == r + @test U1 ≈ U2 + @test S1 ≈ S2 + @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) + U3, S3, V3ᴴ, ϵ3 = @testinferred svd_trunc(A; trunc) + @test length(diagview(S3)) == r + @test U1 ≈ U3 + @test S1 ≈ S3 + @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + end + + @testset "mix maxrank and tol" begin + m4 = 4 + U = instantiate_unitary(T, A, m4) + Sdiag = similar(A, real(eltype(T)), m4) + copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01]) + S = Diagonal(Sdiag) + Vᴴ = instantiate_unitary(T, A, m4) + A = U * S * Vᴴ + for trunc_fun in ( + (rtol, maxrank) -> (; rtol, maxrank), + (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), + ) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.2, 1)) + @test length(diagview(S1)) == 1 + @test diagview(S1) ≈ diagview(S)[1:1] + + U2, S2, V2ᴴ = svd_trunc_no_error(A; trunc = trunc_fun(0.2, 3)) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] + end + end + @testset "specify truncation algorithm" begin + atol = sqrt(eps(real(eltype(T)))) + m4 = 4 + U = instantiate_unitary(T, A, m4) + Sdiag = similar(A, real(eltype(T)), m4) + copyto!(Sdiag, [0.9, 0.3, 0.1, 0.01]) + Vᴴ = instantiate_unitary(T, A, m4) + S = Diagonal(Sdiag) + A = U * S * Vᴴ + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(; atol = 0.2)) + U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] + @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) + end + end +end + +function test_svd_trunc_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(eltype(T)), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "svd_trunc! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + m, n = size(A) + minmn = min(m, n) + S₀ = svd_vals(A) + r = minmn - 2 + + if m > 0 && n > 0 + U1, S1, V1ᴴ, ϵ1 = @testinferred svd_trunc(A; trunc = truncrank(r), alg) + @test length(diagview(S1)) == r + @test diagview(S1) ≈ S₀[1:r] + @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + s = 1 + sqrt(eps(real(eltype(T)))) + trunc = trunctol(; atol = s * S₀[r + 1]) + + U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; trunc, alg) + @test length(diagview(S2)) == r + @test U1 ≈ U2 + @test S1 ≈ S2 + @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + + trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) + U3, S3, V3ᴴ, ϵ3 = @testinferred svd_trunc(A; trunc, alg) + @test length(diagview(S3)) == r + @test U1 ≈ U3 + @test S1 ≈ S3 + @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + end + + @testset "mix maxrank and tol" begin + m4 = 4 + U = instantiate_unitary(T, A, m4) + Sdiag = similar(A, real(eltype(T)), m4) + copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + S = Diagonal(Sdiag) + Vᴴ = instantiate_unitary(T, A, m4) + A = U * S * Vᴴ + for trunc_fun in ( + (rtol, maxrank) -> (; rtol, maxrank), + (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), + ) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; trunc = trunc_fun(0.2, 1), alg) + @test length(diagview(S1)) == 1 + @test diagview(S1) ≈ diagview(S)[1:1] + + U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; trunc = trunc_fun(0.2, 3), alg) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] + end + end + @testset "specify truncation algorithm" begin + atol = sqrt(eps(real(eltype(T)))) + m4 = 4 + U = instantiate_unitary(T, A, m4) + Sdiag = similar(A, real(eltype(T)), m4) + copyto!(Sdiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + S = Diagonal(Sdiag) + Vᴴ = instantiate_unitary(T, A, m4) + A = U * S * Vᴴ + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = 0.2)) + U2, S2, V2ᴴ, ϵ2 = @testinferred svd_trunc(A; alg = truncalg) + @test diagview(S2) ≈ diagview(S)[1:2] + @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + @test_throws ArgumentError svd_trunc(A; alg = truncalg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_no_error(A; alg = truncalg, trunc = (; maxrank = 2)) + end + end +end