Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixAlgebraKit"
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
authors = ["Jutho <[email protected]> and contributors"]
version = "0.2.1"
version = "0.2.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
1 change: 0 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F
throw(ArgumentError("Unknown alg $alg"))
end


@doc """
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}
Expand Down
53 changes: 31 additions & 22 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,30 @@
filter::F
end

struct TruncationKeepAbove{T<:Real} <: TruncationStrategy
struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy
atol::T
rtol::T
p::Int
by::F
end
function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepAbove(atol, rtol, p, by)
end
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2)
return TruncationKeepAbove(promote(atol, rtol)..., p)
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepAbove(promote(atol, rtol)..., p, by)
end

struct TruncationKeepBelow{T<:Real} <: TruncationStrategy
struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy
atol::T
rtol::T
p::Int
by::F
end
function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepBelow(atol, rtol, p, by)
end
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2)
return TruncationKeepBelow(promote(atol, rtol)..., p)
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs)
return TruncationKeepBelow(promote(atol, rtol)..., p, by)
end

# TODO: better names for these functions of the above types
Expand All @@ -94,18 +102,18 @@
truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev)

"""
trunctol(atol::Real)
trunctol(atol::Real; by=abs)

Truncation strategy to discard the values that are smaller than `atol` in absolute value.
Truncation strategy to discard the values that are smaller than `atol` according to `by`.
"""
trunctol(atol) = TruncationKeepFiltered(≥(atol) ∘ abs)
trunctol(atol; by=abs) = TruncationKeepFiltered(≥(atol) ∘ by)

"""
truncabove(atol::Real)
truncabove(atol::Real; by=abs)

Truncation strategy to discard the values that are larger than `atol` in absolute value.
Truncation strategy to discard the values that are larger than `atol` according to `by`.
"""
truncabove(atol) = TruncationKeepFiltered(≤(atol) ∘ abs)
truncabove(atol; by=abs) = TruncationKeepFiltered(≤(atol) ∘ by)

Check warning on line 116 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L116

Added line #L116 was not covered by tests

"""
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
Expand Down Expand Up @@ -177,17 +185,18 @@
based on the `strategy`. The output should be a collection of indices specifying
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
values are sorted. For a version that assumes the values are reverse sorted by
absolute value (which is the standard case for SVD) see
[`MatrixAlgebraKit.findtruncated_sorted`](@ref).
values are sorted. For a version that assumes the values are reverse sorted (which is the
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
""" findtruncated

@doc """
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)

Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
absolute value. However, note that this assumption is not checked, so passing values that are not sorted
in that way can silently give unexpected results. This is used in the default implementation of
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
They are assumed to be sorted in a way that is consistent with the truncation strategy,
which generally means they are sorted by absolute value but some truncation strategies allow
customizing that. However, note that this assumption is not checked, so passing values that are not sorted
in the correct way can silently give unexpected results. This is used in the default implementation of
[`svd_trunc!`](@ref).
""" findtruncated_sorted

Expand All @@ -212,21 +221,21 @@

function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≤(atol), values)
return findall(≤(atol) ∘ strategy.by, values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedfirst(values, atol; by=abs, rev=true)
i = searchsortedfirst(values, atol; by=strategy.by, rev=true)
return i:length(values)
end

function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≥(atol), values)
return findall(≥(atol) ∘ strategy.by, values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedlast(values, atol; by=abs, rev=true)
i = searchsortedlast(values, atol; by=strategy.by, rev=true)
return 1:i
end

Expand Down
48 changes: 39 additions & 9 deletions test/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
TruncationKeepBelow, TruncationStrategy, findtruncated
TruncationKeepBelow, TruncationStrategy, findtruncated,
findtruncated_sorted

@testset "truncate" begin
trunc = @constinferred TruncationStrategy()
Expand All @@ -27,16 +28,45 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
@test trunc.components[1] == truncrank(10)
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)

values = [1, 0.9, 0.5, 0.3, 0.01]
values = [1, 0.9, 0.5, -0.3, 0.01]
@test @constinferred(findtruncated(values, truncrank(2))) == 1:2
@test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4]
@test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4]
@test @constinferred(findtruncated(values, truncrank(2; by=((-) ∘ abs)))) == [5, 4]
@test @constinferred(findtruncated_sorted(values, truncrank(2))) === 1:2

values = [1, 0.9, 0.5, 0.3, 0.01]
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5
values = [1, 0.9, 0.5, -0.3, 0.01]
for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0),
TruncationKeepAbove(0.4, 0))
@test @constinferred(findtruncated(values, strategy)) == 1:3
@test @constinferred(findtruncated_sorted(values, strategy)) === 1:3
end
for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0),
TruncationKeepBelow(0.4, 0))
@test @constinferred(findtruncated(values, strategy)) == 4:5
@test @constinferred(findtruncated_sorted(values, strategy)) === 4:5
end

values = [0.01, 1, 0.9, 0.3, 0.5]
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5]
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4]
values = [0.01, 1, 0.9, -0.3, 0.5]
for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0),
TruncationKeepAbove(; atol=0.4, rtol=0, by=abs),
TruncationKeepAbove(0.4, 0),
TruncationKeepAbove(; atol=0.2, rtol=0.0, by=identity))
@test @constinferred(findtruncated(values, strategy)) == [2, 3, 5]
end
for strategy in (TruncationKeepAbove(; atol=0.2, rtol=0),
TruncationKeepAbove(; atol=0.2, rtol=0, by=abs),
TruncationKeepAbove(0.2, 0))
@test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5]
end
for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0),
TruncationKeepBelow(; atol=0.4, rtol=0, by=abs),
TruncationKeepBelow(0.4, 0),
TruncationKeepBelow(; atol=0.2, rtol=0.0, by=identity))
@test @constinferred(findtruncated(values, strategy)) == [1, 4]
end
for strategy in (TruncationKeepBelow(; atol=0.2, rtol=0),
TruncationKeepBelow(; atol=0.2, rtol=0, by=abs),
TruncationKeepBelow(0.2, 0))
@test @constinferred(findtruncated(values, strategy)) == [1]
end
end
Loading