diff --git a/Project.toml b/Project.toml index 78702565..254377c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index f135df10..6088ed7e 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing) if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) return NoTruncation() elseif isnothing(maxrank) - @assert isnothing(rtol) "TODO: rtol" - return trunctol(atol) + atol = @something atol 0 + rtol = @something rtol 0 + return TruncationKeepAbove(atol, rtol) else - return truncrank(maxrank) + if isnothing(atol) && isnothing(rtol) + return truncrank(maxrank) + else + atol = @something atol 0 + rtol = @something rtol 0 + return truncrank(maxrank) & TruncationKeepAbove(atol, rtol) + end end end @@ -82,6 +89,28 @@ Truncation strategy to discard the values that are larger than `atol` in absolut """ truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) +""" + TruncationIntersection(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + +Compose two truncation strategies, keeping values common between the two strategies. +""" +struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: + TruncationStrategy + components::T +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + return TruncationIntersection((trunc1, trunc2)) +end +function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection) + return TruncationIntersection((trunc1.components..., trunc2.components...)) +end +function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy) + return TruncationIntersection((trunc1.components..., trunc2)) +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection) + return TruncationIntersection((trunc1, trunc2.components...)) +end + # truncate! # --------- # Generic implementation: `findtruncated` followed by indexing @@ -147,6 +176,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) return 1:i end +function findtruncated(values::AbstractVector, strategy::TruncationIntersection) + inds = map(Base.Fix1(findtruncated, values), strategy.components) + return intersect(inds...) +end + """ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) diff --git a/test/runtests.jl b/test/runtests.jl index 541c7da8..7ad7c216 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,8 @@ using SafeTestsets +@safetestset "Truncate" begin + include("truncate.jl") +end @safetestset "QR / LQ Decomposition" begin include("qr.jl") include("lq.jl") diff --git a/test/svd.jl b/test/svd.jl index b92c2e45..e40a69e9 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef -using MatrixAlgebraKit: diagview +using MatrixAlgebraKit: TruncationKeepAbove, diagview @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -115,3 +115,33 @@ end end end end + +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + if LinearAlgebra.LAPACK.version() < v"3.12.0" + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) + else + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), + LAPACK_Jacobi()) + end + m = 4 + @testset "algorithm $alg" for alg in algs + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal([0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + + for trunc_fun in ((rtol, maxrank) -> (; rtol, maxrank), + (rtol, maxrank) -> truncrank(maxrank) & TruncationKeepAbove(0, rtol)) + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 1)) + @test length(S1.diag) == 1 + @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 3)) + @test length(S2.diag) == 2 + @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + end + end +end diff --git a/test/truncate.jl b/test/truncate.jl new file mode 100644 index 00000000..5004a39d --- /dev/null +++ b/test/truncate.jl @@ -0,0 +1,29 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove, + TruncationStrategy + +@testset "truncate" begin + trunc = @constinferred TruncationStrategy() + @test trunc isa NoTruncation + + trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3) + @test trunc isa TruncationKeepAbove + @test trunc == TruncationKeepAbove(1e-2, 1e-3) + @test trunc.atol == 1e-2 + @test trunc.rtol == 1e-3 + + trunc = @constinferred TruncationStrategy(; maxrank=10) + @test trunc isa TruncationKeepSorted + @test trunc == truncrank(10) + @test trunc.howmany == 10 + @test trunc.sortby == abs + @test trunc.rev == true + + trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10) + @test trunc isa TruncationIntersection + @test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3) + @test trunc.components[1] == truncrank(10) + @test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3) +end