diff --git a/docs/src/dev_interface.md b/docs/src/dev_interface.md index 4482a5c6..52d44ca1 100644 --- a/docs/src/dev_interface.md +++ b/docs/src/dev_interface.md @@ -10,4 +10,6 @@ MatrixAlgebraKit.jl provides a developer interface for specifying custom algorit ```@docs; canonical=false MatrixAlgebraKit.default_algorithm MatrixAlgebraKit.select_algorithm +MatrixAlgebraKit.findtruncated +MatrixAlgebraKit.findtruncated_sorted ``` diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index a9f48393..54207e60 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -31,7 +31,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered VERSION >= v"1.11.0-DEV.469" && - eval(Expr(:public, :default_algorithm, :select_algorithm)) + eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted, + :select_algorithm)) include("common/defaults.jl") include("common/initialization.jl") diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 1898a010..38af507e 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -48,13 +48,13 @@ end # since these are implicitly discarded by selecting compact/full """ - TruncationKeepSorted(howmany::Int, sortby::Function, rev::Bool) + TruncationKeepSorted(howmany::Int, by::Function, rev::Bool) -Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true. +Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true). """ struct TruncationKeepSorted{F} <: TruncationStrategy howmany::Int - sortby::F + by::F rev::Bool end @@ -70,14 +70,20 @@ end struct TruncationKeepAbove{T<:Real} <: TruncationStrategy atol::T rtol::T + p::Int +end +function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2) + return TruncationKeepAbove(promote(atol, rtol)..., p) end -TruncationKeepAbove(atol::Real, rtol::Real) = TruncationKeepAbove(promote(atol, rtol)...) struct TruncationKeepBelow{T<:Real} <: TruncationStrategy atol::T rtol::T + p::Int +end +function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2) + return TruncationKeepBelow(promote(atol, rtol)..., p) end -TruncationKeepBelow(atol::Real, rtol::Real) = TruncationKeepBelow(promote(atol, rtol)...) # TODO: better names for these functions of the above types """ @@ -137,7 +143,7 @@ Generic interface for post-truncating a decomposition, specified in `out`. """ truncate! # TODO: should we return a view? function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) - ind = findtruncated(diagview(S), strategy) + ind = findtruncated_sorted(diagview(S), strategy) return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :] end function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy) @@ -164,15 +170,38 @@ end # findtruncated # ------------- # specific implementations for finding truncated values +@doc """ + MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy) + +Generic interface for finding truncated values of the spectrum of a decomposition +based on the `strategy`. The output should be a collection of indices specifying +which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default +implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the +values are sorted. For a version that assumes the values are reverse sorted by +absolute value (which is the standard case for SVD) see +[`MatrixAlgebraKit.findtruncated_sorted`](@ref). +""" findtruncated + +@doc """ + MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) + +Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by +absolute value. However, note that this assumption is not checked, so passing values that are not sorted +in that way can silently give unexpected results. This is used in the default implementation of +[`svd_trunc!`](@ref). +""" findtruncated_sorted + findtruncated(values::AbstractVector, ::NoTruncation) = Colon() # TODO: this may also permute the eigenvalues, decide if we want to allow this or not # can be solved by going to simply sorting the resulting `ind` function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted) - sorted = sortperm(values; by=strategy.sortby, rev=strategy.rev) - howmany = min(strategy.howmany, length(sorted)) - ind = sorted[1:howmany] - return ind # TODO: consider sort!(ind) + howmany = min(strategy.howmany, length(values)) + return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev) +end +function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted) + howmany = min(strategy.howmany, length(values)) + return 1:howmany end # TODO: consider if worth using that values are sorted when filter is `<` or `>`. @@ -182,13 +211,22 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered) end function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow) - atol = max(strategy.atol, strategy.rtol * first(values)) - i = @something findfirst(≤(atol), values) length(values) + 1 + atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) + return findall(≤(atol), values) +end +function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow) + atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) + i = searchsortedfirst(values, atol; by=abs, rev=true) return i:length(values) end + function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) - atol = max(strategy.atol, strategy.rtol * first(values)) - i = @something findlast(≥(atol), values) 0 + atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) + return findall(≥(atol), values) +end +function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove) + atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) + i = searchsortedlast(values, atol; by=abs, rev=true) return 1:i end @@ -197,6 +235,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection) return intersect(inds...) end +# Generic fallback. +function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) + return findtruncated(values, strategy) +end + """ TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) diff --git a/test/truncate.jl b/test/truncate.jl index 2e67b894..5c2aedfc 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove, - TruncationStrategy, findtruncated + TruncationKeepBelow, TruncationStrategy, findtruncated @testset "truncate" begin trunc = @constinferred TruncationStrategy() @@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov @test trunc isa TruncationKeepSorted @test trunc == truncrank(10) @test trunc.howmany == 10 - @test trunc.sortby == abs + @test trunc.by == abs @test trunc.rev == true trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10) @@ -28,7 +28,15 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov @test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3) values = [1, 0.9, 0.5, 0.3, 0.01] - @test @constinferred(findtruncated(values, truncrank(2))) == [1, 2] + @test @constinferred(findtruncated(values, truncrank(2))) == 1:2 @test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4] @test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4] + + values = [1, 0.9, 0.5, 0.3, 0.01] + @test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3 + @test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5 + + values = [0.01, 1, 0.9, 0.3, 0.5] + @test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5] + @test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4] end