Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@
# since these are implicitly discarded by selecting compact/full

"""
TruncationKeepSorted(howmany::Int, sortby::Function, rev::Bool)
TruncationKeepSorted(howmany::Int, by::Function, rev::Bool)

Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true.
Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true).
"""
struct TruncationKeepSorted{F} <: TruncationStrategy
howmany::Int
sortby::F
by::F
rev::Bool
end

Expand All @@ -70,14 +70,20 @@
struct TruncationKeepAbove{T<:Real} <: TruncationStrategy
atol::T
rtol::T
p::Int
end
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2)
return TruncationKeepAbove(promote(atol, rtol)..., p)
end
TruncationKeepAbove(atol::Real, rtol::Real) = TruncationKeepAbove(promote(atol, rtol)...)

struct TruncationKeepBelow{T<:Real} <: TruncationStrategy
atol::T
rtol::T
p::Int
end
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2)
return TruncationKeepBelow(promote(atol, rtol)..., p)
end
TruncationKeepBelow(atol::Real, rtol::Real) = TruncationKeepBelow(promote(atol, rtol)...)

# TODO: better names for these functions of the above types
"""
Expand Down Expand Up @@ -137,7 +143,7 @@
""" truncate!
# TODO: should we return a view?
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated(diagview(S), strategy)
ind = findtruncated_sorted(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
end
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
Expand Down Expand Up @@ -169,10 +175,12 @@
# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
# can be solved by going to simply sorting the resulting `ind`
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
sorted = sortperm(values; by=strategy.sortby, rev=strategy.rev)
howmany = min(strategy.howmany, length(sorted))
ind = sorted[1:howmany]
return ind # TODO: consider sort!(ind)
howmany = min(strategy.howmany, length(values))
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted)
howmany = min(strategy.howmany, length(values))
return 1:howmany
end

# TODO: consider if worth using that values are sorted when filter is `<` or `>`.
Expand All @@ -182,13 +190,22 @@
end

function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
atol = max(strategy.atol, strategy.rtol * first(values))
i = @something findfirst(≤(atol), values) length(values) + 1
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≤(atol), values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedfirst(values, atol; by=abs, rev=true)

Check warning on line 198 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L196-L198

Added lines #L196 - L198 were not covered by tests
return i:length(values)
end

function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * first(values))
i = @something findlast(≥(atol), values) 0
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≥(atol), values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedlast(values, atol; by=abs, rev=true)
return 1:i
end

Expand All @@ -197,6 +214,11 @@
return intersect(inds...)
end

# Generic fallback.
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
14 changes: 11 additions & 3 deletions test/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
TruncationStrategy, findtruncated
TruncationKeepBelow, TruncationStrategy, findtruncated

@testset "truncate" begin
trunc = @constinferred TruncationStrategy()
Expand All @@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
@test trunc isa TruncationKeepSorted
@test trunc == truncrank(10)
@test trunc.howmany == 10
@test trunc.sortby == abs
@test trunc.by == abs
@test trunc.rev == true

trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)
Expand All @@ -28,7 +28,15 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)

values = [1, 0.9, 0.5, 0.3, 0.01]
@test @constinferred(findtruncated(values, truncrank(2))) == [1, 2]
@test @constinferred(findtruncated(values, truncrank(2))) == 1:2
@test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4]
@test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4]

values = [1, 0.9, 0.5, 0.3, 0.01]
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5

values = [0.01, 1, 0.9, 0.3, 0.5]
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5]
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4]
end
Loading