From 5adbbfff94ea1c4813c8cd1efebe98aeea684a86 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 09:04:04 -0400 Subject: [PATCH 1/7] Truncation composition --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 2 +- src/algorithms.jl | 2 +- src/implementations/orthnull.jl | 2 +- src/implementations/truncation.jl | 33 +++++++++++++++++++++--- src/interface/orthnull.jl | 2 +- src/pullbacks/polar.jl | 2 +- test/chainrules.jl | 2 +- test/orthnull.jl | 2 +- test/svd.jl | 26 +++++++++++++++++++ 9 files changed, 63 insertions(+), 10 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 5ef1fd9b..e21842c6 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -176,4 +176,4 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, return PWᴴ, right_polar_pullback end -end \ No newline at end of file +end diff --git a/src/algorithms.jl b/src/algorithms.jl index ea260aa2..fc247317 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -178,4 +178,4 @@ macro check_size(x, sz, size=:size) string($sz) szx == $sz || throw(DimensionMismatch($err)) end) -end \ No newline at end of file +end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index e134121b..c34df721 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -215,4 +215,4 @@ function right_null!(A::AbstractMatrix, Nᴴ; kwargs...) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end -end \ No newline at end of file +end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index f135df10..a2c61e11 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing) if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) return NoTruncation() elseif isnothing(maxrank) - @assert isnothing(rtol) "TODO: rtol" - return trunctol(atol) + atol = @something atol 0 + rtol = @something rtol 0 + return TruncationKeepAbove(atol, rtol) else - return truncrank(maxrank) + if isnothing(atol) && isnothing(rtol) + return truncrank(maxrank) + else + atol = @something atol 0 + rtol = @something rtol 0 + return truncrank(maxrank) & TruncationKeepAbove(atol, rtol) + end end end @@ -82,6 +89,20 @@ Truncation strategy to discard the values that are larger than `atol` in absolut """ truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) +""" + TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + +Compose two truncation strategies, keeping values common between the two strategies. +""" +struct TruncationComposition{T1<:TruncationStrategy,T2<:TruncationStrategy} <: + TruncationStrategy + trunc1::T1 + trunc2::T2 +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + return TruncationComposition(trunc1, trunc2) +end + # truncate! # --------- # Generic implementation: `findtruncated` followed by indexing @@ -147,6 +168,12 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) return 1:i end +function findtruncated(values::AbstractVector, strategy::TruncationComposition) + ind1 = findtruncated(values, strategy.trunc1) + ind2 = findtruncated(values, strategy.trunc2) + return ind1 ∩ ind2 +end + """ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index de7c3153..78a60568 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -224,4 +224,4 @@ function right_null!(A::AbstractMatrix; kwargs...) end function right_null(A::AbstractMatrix; kwargs...) return right_null!(copy_input(right_null, A); kwargs...) -end \ No newline at end of file +end diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index ed6fc17d..2eea389e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -58,4 +58,4 @@ function right_polar_pullback!(ΔA::AbstractMatrix, PWᴴ, ΔPWᴴ) ΔA .+= PΔWᴴ end return ΔA -end \ No newline at end of file +end diff --git a/test/chainrules.jl b/test/chainrules.jl index 47de1c04..a93c5270 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -356,4 +356,4 @@ end test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end -end \ No newline at end of file +end diff --git a/test/orthnull.jl b/test/orthnull.jl index 7b926b59..7784fc82 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -209,4 +209,4 @@ end end end end -end \ No newline at end of file +end diff --git a/test/svd.jl b/test/svd.jl index 974957f0..8251a2b0 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -108,3 +108,29 @@ end end end end + +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + if LinearAlgebra.LAPACK.version() < v"3.12.0" + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) + else + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), + LAPACK_Jacobi()) + end + m = 4 + @testset "algorithm $alg" for alg in algs + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal([0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1)) + @test length(S1.diag) == 1 + @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3)) + @test length(S2.diag) == 2 + @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + end +end From c691803161477948462ae9e9d0665ad0520f391a Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 8 Apr 2025 17:38:44 -0400 Subject: [PATCH 2/7] Fix some orthnull cross references --- src/interface/orthnull.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 78a60568..236184bd 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -62,7 +62,7 @@ of `kind`. the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CV` as output. -See also [`right_orth(!)`](@ref right_orth), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ function left_orth end function left_orth! end @@ -117,7 +117,7 @@ of `kind`. the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CVᴴ` as output. -See also [`left_orth(!)`](@ref left_orth), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ function right_orth end function right_orth! end From 9707ab1d31e3a717673eb355058baa058a3bce71 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 9 Apr 2025 10:34:08 -0400 Subject: [PATCH 3/7] Apply JuliaFormatter v2 formatting changes (#16) * Format * Add vim temp files to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4f2cc4b7..73664508 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.jl.*.cov *.jl.cov *.jl.mem +.*.swp Manifest.toml docs/build/ From 77cc6c31e13ec88f0faa36d5a7bbe52b38b6ffb8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 11:52:16 -0400 Subject: [PATCH 4/7] Flatten composition of compositions --- src/implementations/truncation.jl | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index a2c61e11..73b59c97 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -94,13 +94,21 @@ truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) Compose two truncation strategies, keeping values common between the two strategies. """ -struct TruncationComposition{T1<:TruncationStrategy,T2<:TruncationStrategy} <: +struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <: TruncationStrategy - trunc1::T1 - trunc2::T2 + components::T end function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) - return TruncationComposition(trunc1, trunc2) + return TruncationComposition((trunc1, trunc2)) +end +function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition) + return TruncationComposition((trunc1.components..., trunc2.components...)) +end +function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy) + return TruncationComposition((trunc1.components..., trunc2)) +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition) + return TruncationComposition((trunc1, trunc2.components...)) end # truncate! @@ -169,9 +177,8 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) end function findtruncated(values::AbstractVector, strategy::TruncationComposition) - ind1 = findtruncated(values, strategy.trunc1) - ind2 = findtruncated(values, strategy.trunc2) - return ind1 ∩ ind2 + inds = map(Base.Fix1(findtruncated, values), strategy.components) + return intersect(inds...) end """ From f56525c857a6d9797f8710faaa71b0c6b692c8a6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 11:57:12 -0400 Subject: [PATCH 5/7] Format --- test/svd.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/svd.jl b/test/svd.jl index 8251a2b0..ed7b502b 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -109,8 +109,9 @@ end end end -@testset "svd_trunc! mix maxrank and tol for T = $T" for T in (Float32, Float64, ComplexF32, - ComplexF64) +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) rng = StableRNG(123) if LinearAlgebra.LAPACK.version() < v"3.12.0" algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) From 6c5c1fd15f0bebecf87ec8fc41616a2e929c755b Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 9 Apr 2025 12:06:30 -0400 Subject: [PATCH 6/7] Use SafeTestsets.jl in tests (#17) * Use SafeTestsets in tests * Use SafeTestsets in tests * Cleanup * Fix missing import --- Project.toml | 4 +++- test/chainrules.jl | 6 +++++- test/eig.jl | 7 +++++++ test/eigh.jl | 7 +++++++ test/lq.jl | 6 ++++++ test/orthnull.jl | 6 ++++++ test/polar.jl | 5 +++++ test/qr.jl | 6 ++++++ test/runtests.jl | 36 ++++++++++++++++-------------------- test/schur.jl | 6 ++++++ test/svd.jl | 7 +++++++ 11 files changed, 74 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index d8ea7ab4..78702565 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1" JET = "0.9" LinearAlgebra = "1" +SafeTestsets = "0.1" StableRNGs = "1" Test = "1" TestExtras = "0.2,0.3" @@ -28,10 +29,11 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] diff --git a/test/chainrules.jl b/test/chainrules.jl index a93c5270..ab895a4a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,6 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs using ChainRulesCore, ChainRulesTestUtils, Zygote using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! function remove_svdgauge_depence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) diff --git a/test/eig.jl b/test/eig.jl index da9c3961..d0dc13e3 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -1,3 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: Diagonal +using MatrixAlgebraKit: diagview + @testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/eigh.jl b/test/eigh.jl index 8967f665..5a3c5a8a 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -1,3 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I +using MatrixAlgebraKit: diagview + @testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/lq.jl b/test/lq.jl index b1643080..10375e50 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I + @testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/orthnull.jl b/test/orthnull.jl index 7784fc82..0a93c593 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, I + @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) diff --git a/test/polar.jl b/test/polar.jl index 064e0e59..513654ea 100644 --- a/test/polar.jl +++ b/test/polar.jl @@ -1,3 +1,8 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, I, isposdef using MatrixAlgebraKit: PolarViaSVD @testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) diff --git a/test/qr.jl b/test/qr.jl index 9fb96408..81c64f0e 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I + @testset "qr_compact! and qr_null! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) diff --git a/test/runtests.jl b/test/runtests.jl index 07a54125..541c7da8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,44 +1,40 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using ChainRulesTestUtils -using StableRNGs -using Aqua -using JET -using LinearAlgebra: LinearAlgebra, diag, Diagonal, I, isposdef, diagind, mul! -using MatrixAlgebraKit: diagview +using SafeTestsets -@testset "QR / LQ Decomposition" begin +@safetestset "QR / LQ Decomposition" begin include("qr.jl") include("lq.jl") end -@testset "Singular Value Decomposition" begin +@safetestset "Singular Value Decomposition" begin include("svd.jl") end -@testset "Hermitian Eigenvalue Decomposition" begin +@safetestset "Hermitian Eigenvalue Decomposition" begin include("eigh.jl") end -@testset "General Eigenvalue Decomposition" begin +@safetestset "General Eigenvalue Decomposition" begin include("eig.jl") end -@testset "Schur Decomposition" begin +@safetestset "Schur Decomposition" begin include("schur.jl") end -@testset "Polar Decomposition" begin +@safetestset "Polar Decomposition" begin include("polar.jl") end -@testset "Image and Null Space" begin +@safetestset "Image and Null Space" begin include("orthnull.jl") end -@testset "ChainRules" verbose = true begin +@safetestset "ChainRules" begin include("chainrules.jl") end -@testset "MatrixAlgebraKit.jl" begin - @testset "Code quality (Aqua.jl)" begin +@safetestset "MatrixAlgebraKit.jl" begin + @safetestset "Code quality (Aqua.jl)" begin + using MatrixAlgebraKit + using Aqua Aqua.test_all(MatrixAlgebraKit) end - @testset "Code linting (JET.jl)" begin + @safetestset "Code linting (JET.jl)" begin + using MatrixAlgebraKit + using JET JET.test_package(MatrixAlgebraKit; target_defined_modules=true) end end diff --git a/test/schur.jl b/test/schur.jl index 05f952b5..71404e89 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: I + @testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/svd.jl b/test/svd.jl index ed7b502b..ee4b74e0 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -1,3 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef +using MatrixAlgebraKit: diagview + @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 From f0b5184a17b61161b164fb74c04522fb67d5f12e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 9 Apr 2025 12:18:23 -0400 Subject: [PATCH 7/7] Bump to v0.1.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 78702565..254377c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"