Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/dev_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ MatrixAlgebraKit.jl provides a developer interface for specifying custom algorit
```@docs; canonical=false
MatrixAlgebraKit.default_algorithm
MatrixAlgebraKit.select_algorithm
MatrixAlgebraKit.findtruncated
MatrixAlgebraKit.findtruncated_sorted
```
3 changes: 2 additions & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered

VERSION >= v"1.11.0-DEV.469" &&
eval(Expr(:public, :default_algorithm, :select_algorithm))
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted,
:select_algorithm))

include("common/defaults.jl")
include("common/initialization.jl")
Expand Down
71 changes: 57 additions & 14 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@
# since these are implicitly discarded by selecting compact/full

"""
TruncationKeepSorted(howmany::Int, sortby::Function, rev::Bool)
TruncationKeepSorted(howmany::Int, by::Function, rev::Bool)

Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true.
Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true).
"""
struct TruncationKeepSorted{F} <: TruncationStrategy
howmany::Int
sortby::F
by::F
rev::Bool
end

Expand All @@ -70,14 +70,20 @@
struct TruncationKeepAbove{T<:Real} <: TruncationStrategy
atol::T
rtol::T
p::Int
end
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2)
return TruncationKeepAbove(promote(atol, rtol)..., p)
end
TruncationKeepAbove(atol::Real, rtol::Real) = TruncationKeepAbove(promote(atol, rtol)...)

struct TruncationKeepBelow{T<:Real} <: TruncationStrategy
atol::T
rtol::T
p::Int
end
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2)
return TruncationKeepBelow(promote(atol, rtol)..., p)
end
TruncationKeepBelow(atol::Real, rtol::Real) = TruncationKeepBelow(promote(atol, rtol)...)

# TODO: better names for these functions of the above types
"""
Expand Down Expand Up @@ -137,7 +143,7 @@
""" truncate!
# TODO: should we return a view?
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated(diagview(S), strategy)
ind = findtruncated_sorted(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
end
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
Expand All @@ -164,15 +170,38 @@
# findtruncated
# -------------
# specific implementations for finding truncated values
@doc """
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)

Generic interface for finding truncated values of the spectrum of a decomposition
based on the `strategy`. The output should be a collection of indices specifying
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
values are sorted. For a version that assumes the values are reverse sorted by
absolute value (which is the standard case for SVD) see
[`MatrixAlgebraKit.findtruncated_sorted`](@ref).
""" findtruncated

@doc """
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)

Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
absolute value. However, note that this assumption is not checked, so passing values that are not sorted
in that way can silently give unexpected results. This is used in the default implementation of
[`svd_trunc!`](@ref).
""" findtruncated_sorted

findtruncated(values::AbstractVector, ::NoTruncation) = Colon()

# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
# can be solved by going to simply sorting the resulting `ind`
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
sorted = sortperm(values; by=strategy.sortby, rev=strategy.rev)
howmany = min(strategy.howmany, length(sorted))
ind = sorted[1:howmany]
return ind # TODO: consider sort!(ind)
howmany = min(strategy.howmany, length(values))
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepSorted)
howmany = min(strategy.howmany, length(values))
return 1:howmany
end

# TODO: consider if worth using that values are sorted when filter is `<` or `>`.
Expand All @@ -182,13 +211,22 @@
end

function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
atol = max(strategy.atol, strategy.rtol * first(values))
i = @something findfirst(≤(atol), values) length(values) + 1
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≤(atol), values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedfirst(values, atol; by=abs, rev=true)

Check warning on line 219 in src/implementations/truncation.jl

View check run for this annotation

Codecov / codecov/patch

src/implementations/truncation.jl#L217-L219

Added lines #L217 - L219 were not covered by tests
return i:length(values)
end

function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * first(values))
i = @something findlast(≥(atol), values) 0
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
return findall(≥(atol), values)
end
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
i = searchsortedlast(values, atol; by=abs, rev=true)
return 1:i
end

Expand All @@ -197,6 +235,11 @@
return intersect(inds...)
end

# Generic fallback.
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
return findtruncated(values, strategy)
end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand Down
14 changes: 11 additions & 3 deletions test/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
TruncationStrategy, findtruncated
TruncationKeepBelow, TruncationStrategy, findtruncated

@testset "truncate" begin
trunc = @constinferred TruncationStrategy()
Expand All @@ -18,7 +18,7 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
@test trunc isa TruncationKeepSorted
@test trunc == truncrank(10)
@test trunc.howmany == 10
@test trunc.sortby == abs
@test trunc.by == abs
@test trunc.rev == true

trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)
Expand All @@ -28,7 +28,15 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)

values = [1, 0.9, 0.5, 0.3, 0.01]
@test @constinferred(findtruncated(values, truncrank(2))) == [1, 2]
@test @constinferred(findtruncated(values, truncrank(2))) == 1:2
@test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4]
@test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4]

values = [1, 0.9, 0.5, 0.3, 0.01]
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5

values = [0.01, 1, 0.9, 0.3, 0.5]
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5]
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4]
end
Loading