From 86d7f0c721866e7c608facc9d039c57ac61fae7d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 12:23:48 -0400 Subject: [PATCH 1/4] Truncation composition --- Project.toml | 2 +- src/implementations/truncation.jl | 39 ++++++++++++++++++++++++++++--- test/svd.jl | 27 +++++++++++++++++++++ 3 files changed, 64 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 78702565..254377c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index f135df10..4173fe05 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing) if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) return NoTruncation() elseif isnothing(maxrank) - @assert isnothing(rtol) "TODO: rtol" - return trunctol(atol) + atol = @something atol 0 + rtol = @something rtol 0 + return TruncationKeepAbove(atol, rtol) else - return truncrank(maxrank) + if isnothing(atol) && isnothing(rtol) + return truncrank(maxrank) + else + atol = @something atol 0 + rtol = @something rtol 0 + return truncrank(maxrank) & TruncationKeepAbove(atol, rtol) + end end end @@ -82,6 +89,27 @@ Truncation strategy to discard the values that are larger than `atol` in absolut """ truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) +""" + TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy) +Compose two truncation strategies, keeping values common between the two strategies. +""" +struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <: + TruncationStrategy + components::T +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + return TruncationComposition((trunc1, trunc2)) +end +function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition) + return TruncationComposition((trunc1.components..., trunc2.components...)) +end +function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy) + return TruncationComposition((trunc1.components..., trunc2)) +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition) + return TruncationComposition((trunc1, trunc2.components...)) +end + # truncate! # --------- # Generic implementation: `findtruncated` followed by indexing @@ -147,6 +175,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) return 1:i end +function findtruncated(values::AbstractVector, strategy::TruncationComposition) + inds = map(Base.Fix1(findtruncated, values), strategy.components) + return intersect(inds...) +end + """ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) diff --git a/test/svd.jl b/test/svd.jl index b92c2e45..ee4b74e0 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -115,3 +115,30 @@ end end end end + +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + if LinearAlgebra.LAPACK.version() < v"3.12.0" + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) + else + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), + LAPACK_Jacobi()) + end + m = 4 + @testset "algorithm $alg" for alg in algs + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal([0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1)) + @test length(S1.diag) == 1 + @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3)) + @test length(S2.diag) == 2 + @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + end +end From 48ec2f02e9a670559ed62dd480f45d9e4f1f3840 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 14:14:26 -0400 Subject: [PATCH 2/4] Add tests for truncation objects --- test/runtests.jl | 3 +++ test/svd.jl | 19 ++++++++++++------- test/truncate.jl | 29 +++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 test/truncate.jl diff --git a/test/runtests.jl b/test/runtests.jl index 541c7da8..7ad7c216 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,8 @@ using SafeTestsets +@safetestset "Truncate" begin + include("truncate.jl") +end @safetestset "QR / LQ Decomposition" begin include("qr.jl") include("lq.jl") diff --git a/test/svd.jl b/test/svd.jl index ee4b74e0..380ed180 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef -using MatrixAlgebraKit: diagview +using MatrixAlgebraKit: TruncationKeepAbove, diagview @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -133,12 +133,17 @@ end Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ - U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1)) - @test length(S1.diag) == 1 - @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + for (rtol, maxrank) in ((0.2, 1), (0.2, 3)) + for trunc in ((; rtol, maxrank), + truncrank(maxrank) & TruncationKeepAbove(0, rtol)) + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1)) + @test length(S1.diag) == 1 + @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) - U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3)) - @test length(S2.diag) == 2 - @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3)) + @test length(S2.diag) == 2 + @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + end + end end end diff --git a/test/truncate.jl b/test/truncate.jl new file mode 100644 index 00000000..f0cb49a8 --- /dev/null +++ b/test/truncate.jl @@ -0,0 +1,29 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using MatrixAlgebraKit: NoTruncation, TruncationComposition, TruncationKeepAbove, + TruncationStrategy + +@testset "truncate" begin + trunc = @constinferred TruncationStrategy() + @test trunc isa NoTruncation + + trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3) + @test trunc isa TruncationKeepAbove + @test trunc == TruncationKeepAbove(1e-2, 1e-3) + @test trunc.atol == 1e-2 + @test trunc.rtol == 1e-3 + + trunc = @constinferred TruncationStrategy(; maxrank=10) + @test trunc isa TruncationKeepSorted + @test trunc == truncrank(10) + @test trunc.howmany == 10 + @test trunc.sortby == abs + @test trunc.rev == true + + trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10) + @test trunc isa TruncationComposition + @test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3) + @test trunc.components[1] == truncrank(10) + @test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3) +end From ea0aa6e6c62114b6622c354615635ce451d03583 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 16:28:34 -0400 Subject: [PATCH 3/4] Fix truncated SVD test logic --- test/svd.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/svd.jl b/test/svd.jl index 380ed180..e40a69e9 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -133,17 +133,15 @@ end Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ - for (rtol, maxrank) in ((0.2, 1), (0.2, 3)) - for trunc in ((; rtol, maxrank), - truncrank(maxrank) & TruncationKeepAbove(0, rtol)) - U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1)) - @test length(S1.diag) == 1 - @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + for trunc_fun in ((rtol, maxrank) -> (; rtol, maxrank), + (rtol, maxrank) -> truncrank(maxrank) & TruncationKeepAbove(0, rtol)) + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 1)) + @test length(S1.diag) == 1 + @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) - U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3)) - @test length(S2.diag) == 2 - @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) - end + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 3)) + @test length(S2.diag) == 2 + @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) end end end From 3d12ca493bd3c37fe71cbceb0e7823f8daa7ff27 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 17 Apr 2025 09:30:56 -0400 Subject: [PATCH 4/4] Change name to TruncationIntersection --- src/implementations/truncation.jl | 21 +++++++++++---------- test/truncate.jl | 4 ++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 4173fe05..6088ed7e 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -90,24 +90,25 @@ Truncation strategy to discard the values that are larger than `atol` in absolut truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) """ - TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + TruncationIntersection(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + Compose two truncation strategies, keeping values common between the two strategies. """ -struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <: +struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy components::T end function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) - return TruncationComposition((trunc1, trunc2)) + return TruncationIntersection((trunc1, trunc2)) end -function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition) - return TruncationComposition((trunc1.components..., trunc2.components...)) +function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection) + return TruncationIntersection((trunc1.components..., trunc2.components...)) end -function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy) - return TruncationComposition((trunc1.components..., trunc2)) +function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy) + return TruncationIntersection((trunc1.components..., trunc2)) end -function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition) - return TruncationComposition((trunc1, trunc2.components...)) +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection) + return TruncationIntersection((trunc1, trunc2.components...)) end # truncate! @@ -175,7 +176,7 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) return 1:i end -function findtruncated(values::AbstractVector, strategy::TruncationComposition) +function findtruncated(values::AbstractVector, strategy::TruncationIntersection) inds = map(Base.Fix1(findtruncated, values), strategy.components) return intersect(inds...) end diff --git a/test/truncate.jl b/test/truncate.jl index f0cb49a8..5004a39d 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -1,7 +1,7 @@ using MatrixAlgebraKit using Test using TestExtras -using MatrixAlgebraKit: NoTruncation, TruncationComposition, TruncationKeepAbove, +using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove, TruncationStrategy @testset "truncate" begin @@ -22,7 +22,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationComposition, TruncationKeepAbove @test trunc.rev == true trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10) - @test trunc isa TruncationComposition + @test trunc isa TruncationIntersection @test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3) @test trunc.components[1] == truncrank(10) @test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)