diff --git a/.gitignore b/.gitignore index 4f2cc4b7..73664508 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.jl.*.cov *.jl.cov *.jl.mem +.*.swp Manifest.toml docs/build/ diff --git a/Project.toml b/Project.toml index d8ea7ab4..254377c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -18,6 +18,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1" JET = "0.9" LinearAlgebra = "1" +SafeTestsets = "0.1" StableRNGs = "1" Test = "1" TestExtras = "0.2,0.3" @@ -28,10 +29,11 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 5ef1fd9b..e21842c6 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -176,4 +176,4 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, return PWᴴ, right_polar_pullback end -end \ No newline at end of file +end diff --git a/src/algorithms.jl b/src/algorithms.jl index ea260aa2..fc247317 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -178,4 +178,4 @@ macro check_size(x, sz, size=:size) string($sz) szx == $sz || throw(DimensionMismatch($err)) end) -end \ No newline at end of file +end diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index e134121b..c34df721 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -215,4 +215,4 @@ function right_null!(A::AbstractMatrix, Nᴴ; kwargs...) else throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) end -end \ No newline at end of file +end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index f135df10..73b59c97 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing) if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) return NoTruncation() elseif isnothing(maxrank) - @assert isnothing(rtol) "TODO: rtol" - return trunctol(atol) + atol = @something atol 0 + rtol = @something rtol 0 + return TruncationKeepAbove(atol, rtol) else - return truncrank(maxrank) + if isnothing(atol) && isnothing(rtol) + return truncrank(maxrank) + else + atol = @something atol 0 + rtol = @something rtol 0 + return truncrank(maxrank) & TruncationKeepAbove(atol, rtol) + end end end @@ -82,6 +89,28 @@ Truncation strategy to discard the values that are larger than `atol` in absolut """ truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) +""" + TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + +Compose two truncation strategies, keeping values common between the two strategies. +""" +struct TruncationComposition{T<:Tuple{Vararg{TruncationStrategy}}} <: + TruncationStrategy + components::T +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) + return TruncationComposition((trunc1, trunc2)) +end +function Base.:&(trunc1::TruncationComposition, trunc2::TruncationComposition) + return TruncationComposition((trunc1.components..., trunc2.components...)) +end +function Base.:&(trunc1::TruncationComposition, trunc2::TruncationStrategy) + return TruncationComposition((trunc1.components..., trunc2)) +end +function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationComposition) + return TruncationComposition((trunc1, trunc2.components...)) +end + # truncate! # --------- # Generic implementation: `findtruncated` followed by indexing @@ -147,6 +176,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) return 1:i end +function findtruncated(values::AbstractVector, strategy::TruncationComposition) + inds = map(Base.Fix1(findtruncated, values), strategy.components) + return intersect(inds...) +end + """ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index de7c3153..236184bd 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -62,7 +62,7 @@ of `kind`. the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CV` as output. -See also [`right_orth(!)`](@ref right_orth), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ function left_orth end function left_orth! end @@ -117,7 +117,7 @@ of `kind`. the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CVᴴ` as output. -See also [`left_orth(!)`](@ref left_orth), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ function right_orth end function right_orth! end @@ -224,4 +224,4 @@ function right_null!(A::AbstractMatrix; kwargs...) end function right_null(A::AbstractMatrix; kwargs...) return right_null!(copy_input(right_null, A); kwargs...) -end \ No newline at end of file +end diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index ed6fc17d..2eea389e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -58,4 +58,4 @@ function right_polar_pullback!(ΔA::AbstractMatrix, PWᴴ, ΔPWᴴ) ΔA .+= PΔWᴴ end return ΔA -end \ No newline at end of file +end diff --git a/test/chainrules.jl b/test/chainrules.jl index 47de1c04..ab895a4a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,6 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs using ChainRulesCore, ChainRulesTestUtils, Zygote using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! function remove_svdgauge_depence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S)) @@ -356,4 +360,4 @@ end test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ, atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false) end -end \ No newline at end of file +end diff --git a/test/eig.jl b/test/eig.jl index da9c3961..d0dc13e3 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -1,3 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: Diagonal +using MatrixAlgebraKit: diagview + @testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/eigh.jl b/test/eigh.jl index 8967f665..5a3c5a8a 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -1,3 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I +using MatrixAlgebraKit: diagview + @testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/lq.jl b/test/lq.jl index b1643080..10375e50 100644 --- a/test/lq.jl +++ b/test/lq.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I + @testset "lq_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/orthnull.jl b/test/orthnull.jl index 7b926b59..0a93c593 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, I + @testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) @@ -209,4 +215,4 @@ end end end end -end \ No newline at end of file +end diff --git a/test/polar.jl b/test/polar.jl index 064e0e59..513654ea 100644 --- a/test/polar.jl +++ b/test/polar.jl @@ -1,3 +1,8 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, I, isposdef using MatrixAlgebraKit: PolarViaSVD @testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) diff --git a/test/qr.jl b/test/qr.jl index 9fb96408..81c64f0e 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: diag, I + @testset "qr_compact! and qr_null! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) diff --git a/test/runtests.jl b/test/runtests.jl index 07a54125..541c7da8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,44 +1,40 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using ChainRulesTestUtils -using StableRNGs -using Aqua -using JET -using LinearAlgebra: LinearAlgebra, diag, Diagonal, I, isposdef, diagind, mul! -using MatrixAlgebraKit: diagview +using SafeTestsets -@testset "QR / LQ Decomposition" begin +@safetestset "QR / LQ Decomposition" begin include("qr.jl") include("lq.jl") end -@testset "Singular Value Decomposition" begin +@safetestset "Singular Value Decomposition" begin include("svd.jl") end -@testset "Hermitian Eigenvalue Decomposition" begin +@safetestset "Hermitian Eigenvalue Decomposition" begin include("eigh.jl") end -@testset "General Eigenvalue Decomposition" begin +@safetestset "General Eigenvalue Decomposition" begin include("eig.jl") end -@testset "Schur Decomposition" begin +@safetestset "Schur Decomposition" begin include("schur.jl") end -@testset "Polar Decomposition" begin +@safetestset "Polar Decomposition" begin include("polar.jl") end -@testset "Image and Null Space" begin +@safetestset "Image and Null Space" begin include("orthnull.jl") end -@testset "ChainRules" verbose = true begin +@safetestset "ChainRules" begin include("chainrules.jl") end -@testset "MatrixAlgebraKit.jl" begin - @testset "Code quality (Aqua.jl)" begin +@safetestset "MatrixAlgebraKit.jl" begin + @safetestset "Code quality (Aqua.jl)" begin + using MatrixAlgebraKit + using Aqua Aqua.test_all(MatrixAlgebraKit) end - @testset "Code linting (JET.jl)" begin + @safetestset "Code linting (JET.jl)" begin + using MatrixAlgebraKit + using JET JET.test_package(MatrixAlgebraKit; target_defined_modules=true) end end diff --git a/test/schur.jl b/test/schur.jl index 05f952b5..71404e89 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -1,3 +1,9 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: I + @testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 diff --git a/test/svd.jl b/test/svd.jl index 974957f0..ee4b74e0 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -1,3 +1,10 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef +using MatrixAlgebraKit: diagview + @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) m = 54 @@ -108,3 +115,30 @@ end end end end + +@testset "svd_trunc! mix maxrank and tol for T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + if LinearAlgebra.LAPACK.version() < v"3.12.0" + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) + else + algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), + LAPACK_Jacobi()) + end + m = 4 + @testset "algorithm $alg" for alg in algs + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal([0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1)) + @test length(S1.diag) == 1 + @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3)) + @test length(S2.diag) == 2 + @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + end +end