diff --git a/Project.toml b/Project.toml index 507d81df..6736baa5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.2.1" +version = "0.2.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/algorithms.jl b/src/algorithms.jl index f559b42e..6f9a4d49 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -99,7 +99,6 @@ function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F throw(ArgumentError("Unknown alg $alg")) end - @doc """ MatrixAlgebraKit.default_algorithm(f, A; kwargs...) MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA} diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 38af507e..0baecdbe 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -67,22 +67,30 @@ struct TruncationKeepFiltered{F} <: TruncationStrategy filter::F end -struct TruncationKeepAbove{T<:Real} <: TruncationStrategy +struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy atol::T rtol::T p::Int + by::F +end +function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs) + return TruncationKeepAbove(atol, rtol, p, by) end -function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2) - return TruncationKeepAbove(promote(atol, rtol)..., p) +function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs) + return TruncationKeepAbove(promote(atol, rtol)..., p, by) end -struct TruncationKeepBelow{T<:Real} <: TruncationStrategy +struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy atol::T rtol::T p::Int + by::F +end +function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs) + return TruncationKeepBelow(atol, rtol, p, by) end -function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2) - return TruncationKeepBelow(promote(atol, rtol)..., p) +function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs) + return TruncationKeepBelow(promote(atol, rtol)..., p, by) end # TODO: better names for these functions of the above types @@ -94,18 +102,18 @@ Truncation strategy to keep the first `howmany` values when sorted according to truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev) """ - trunctol(atol::Real) + trunctol(atol::Real; by=abs) -Truncation strategy to discard the values that are smaller than `atol` in absolute value. +Truncation strategy to discard the values that are smaller than `atol` according to `by`. """ -trunctol(atol) = TruncationKeepFiltered(≥(atol) ∘ abs) +trunctol(atol; by=abs) = TruncationKeepFiltered(≥(atol) ∘ by) """ - truncabove(atol::Real) + truncabove(atol::Real; by=abs) -Truncation strategy to discard the values that are larger than `atol` in absolute value. +Truncation strategy to discard the values that are larger than `atol` according to `by`. """ -truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs) +truncabove(atol; by=abs) = TruncationKeepFiltered(≤(atol) ∘ by) """ TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) @@ -177,17 +185,18 @@ Generic interface for finding truncated values of the spectrum of a decompositio based on the `strategy`. The output should be a collection of indices specifying which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the -values are sorted. For a version that assumes the values are reverse sorted by -absolute value (which is the standard case for SVD) see -[`MatrixAlgebraKit.findtruncated_sorted`](@ref). +values are sorted. For a version that assumes the values are reverse sorted (which is the +standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref). """ findtruncated @doc """ MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) -Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by -absolute value. However, note that this assumption is not checked, so passing values that are not sorted -in that way can silently give unexpected results. This is used in the default implementation of +Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order. +They are assumed to be sorted in a way that is consistent with the truncation strategy, +which generally means they are sorted by absolute value but some truncation strategies allow +customizing that. However, note that this assumption is not checked, so passing values that are not sorted +in the correct way can silently give unexpected results. This is used in the default implementation of [`svd_trunc!`](@ref). """ findtruncated_sorted @@ -212,21 +221,21 @@ end function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - return findall(≤(atol), values) + return findall(≤(atol) ∘ strategy.by, values) end function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - i = searchsortedfirst(values, atol; by=abs, rev=true) + i = searchsortedfirst(values, atol; by=strategy.by, rev=true) return i:length(values) end function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - return findall(≥(atol), values) + return findall(≥(atol) ∘ strategy.by, values) end function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove) atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p)) - i = searchsortedlast(values, atol; by=abs, rev=true) + i = searchsortedlast(values, atol; by=strategy.by, rev=true) return 1:i end diff --git a/test/truncate.jl b/test/truncate.jl index 5c2aedfc..73c9aff3 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -2,7 +2,8 @@ using MatrixAlgebraKit using Test using TestExtras using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove, - TruncationKeepBelow, TruncationStrategy, findtruncated + TruncationKeepBelow, TruncationStrategy, findtruncated, + findtruncated_sorted @testset "truncate" begin trunc = @constinferred TruncationStrategy() @@ -27,16 +28,45 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov @test trunc.components[1] == truncrank(10) @test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3) - values = [1, 0.9, 0.5, 0.3, 0.01] + values = [1, 0.9, 0.5, -0.3, 0.01] @test @constinferred(findtruncated(values, truncrank(2))) == 1:2 @test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4] - @test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4] + @test @constinferred(findtruncated(values, truncrank(2; by=((-) ∘ abs)))) == [5, 4] + @test @constinferred(findtruncated_sorted(values, truncrank(2))) === 1:2 - values = [1, 0.9, 0.5, 0.3, 0.01] - @test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3 - @test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5 + values = [1, 0.9, 0.5, -0.3, 0.01] + for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0), + TruncationKeepAbove(0.4, 0)) + @test @constinferred(findtruncated(values, strategy)) == 1:3 + @test @constinferred(findtruncated_sorted(values, strategy)) === 1:3 + end + for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0), + TruncationKeepBelow(0.4, 0)) + @test @constinferred(findtruncated(values, strategy)) == 4:5 + @test @constinferred(findtruncated_sorted(values, strategy)) === 4:5 + end - values = [0.01, 1, 0.9, 0.3, 0.5] - @test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5] - @test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4] + values = [0.01, 1, 0.9, -0.3, 0.5] + for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0), + TruncationKeepAbove(; atol=0.4, rtol=0, by=abs), + TruncationKeepAbove(0.4, 0), + TruncationKeepAbove(; atol=0.2, rtol=0.0, by=identity)) + @test @constinferred(findtruncated(values, strategy)) == [2, 3, 5] + end + for strategy in (TruncationKeepAbove(; atol=0.2, rtol=0), + TruncationKeepAbove(; atol=0.2, rtol=0, by=abs), + TruncationKeepAbove(0.2, 0)) + @test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5] + end + for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0), + TruncationKeepBelow(; atol=0.4, rtol=0, by=abs), + TruncationKeepBelow(0.4, 0), + TruncationKeepBelow(; atol=0.2, rtol=0.0, by=identity)) + @test @constinferred(findtruncated(values, strategy)) == [1, 4] + end + for strategy in (TruncationKeepBelow(; atol=0.2, rtol=0), + TruncationKeepBelow(; atol=0.2, rtol=0, by=abs), + TruncationKeepBelow(0.2, 0)) + @test @constinferred(findtruncated(values, strategy)) == [1] + end end