|
1 | | -""" |
2 | | - abstract type TruncationStrategy end |
3 | | -
|
4 | | -Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation. |
5 | | -
|
6 | | -See also [`truncate!`](@ref) |
7 | | -""" |
8 | | -abstract type TruncationStrategy end |
9 | | - |
10 | | -function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing) |
11 | | - if isnothing(maxrank) && isnothing(atol) && isnothing(rtol) |
12 | | - return NoTruncation() |
13 | | - elseif isnothing(maxrank) |
14 | | - atol = @something atol 0 |
15 | | - rtol = @something rtol 0 |
16 | | - return TruncationKeepAbove(atol, rtol) |
17 | | - else |
18 | | - if isnothing(atol) && isnothing(rtol) |
19 | | - return truncrank(maxrank) |
20 | | - else |
21 | | - atol = @something atol 0 |
22 | | - rtol = @something rtol 0 |
23 | | - return truncrank(maxrank) & TruncationKeepAbove(atol, rtol) |
24 | | - end |
25 | | - end |
26 | | -end |
27 | | - |
28 | | -""" |
29 | | - NoTruncation() |
30 | | -
|
31 | | -Trivial truncation strategy that keeps all values, mostly for testing purposes. |
32 | | -""" |
33 | | -struct NoTruncation <: TruncationStrategy end |
34 | | - |
35 | | -function select_truncation(trunc) |
36 | | - if isnothing(trunc) |
37 | | - return NoTruncation() |
38 | | - elseif trunc isa NamedTuple |
39 | | - return TruncationStrategy(; trunc...) |
40 | | - elseif trunc isa TruncationStrategy |
41 | | - return trunc |
42 | | - else |
43 | | - return throw(ArgumentError("Unknown truncation strategy: $trunc")) |
44 | | - end |
45 | | -end |
46 | | - |
47 | | -# TODO: how do we deal with sorting/filters that treat zeros differently |
48 | | -# since these are implicitly discarded by selecting compact/full |
49 | | - |
50 | | -""" |
51 | | - TruncationKeepSorted(howmany::Int, by::Function, rev::Bool) |
52 | | -
|
53 | | -Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true). |
54 | | -""" |
55 | | -struct TruncationKeepSorted{F} <: TruncationStrategy |
56 | | - howmany::Int |
57 | | - by::F |
58 | | - rev::Bool |
59 | | -end |
60 | | - |
61 | | -""" |
62 | | - TruncationKeepFiltered(filter::Function) |
63 | | -
|
64 | | -Truncation strategy to keep the values for which `filter` returns true. |
65 | | -""" |
66 | | -struct TruncationKeepFiltered{F} <: TruncationStrategy |
67 | | - filter::F |
68 | | -end |
69 | | - |
70 | | -struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy |
71 | | - atol::T |
72 | | - rtol::T |
73 | | - p::Int |
74 | | - by::F |
75 | | -end |
76 | | -function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs) |
77 | | - return TruncationKeepAbove(atol, rtol, p, by) |
78 | | -end |
79 | | -function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs) |
80 | | - return TruncationKeepAbove(promote(atol, rtol)..., p, by) |
81 | | -end |
82 | | - |
83 | | -struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy |
84 | | - atol::T |
85 | | - rtol::T |
86 | | - p::Int |
87 | | - by::F |
88 | | -end |
89 | | -function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs) |
90 | | - return TruncationKeepBelow(atol, rtol, p, by) |
91 | | -end |
92 | | -function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs) |
93 | | - return TruncationKeepBelow(promote(atol, rtol)..., p, by) |
94 | | -end |
95 | | - |
96 | | -# TODO: better names for these functions of the above types |
97 | | -""" |
98 | | - truncrank(howmany::Int; by=abs, rev=true) |
99 | | -
|
100 | | -Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true. |
101 | | -""" |
102 | | -truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev) |
103 | | - |
104 | | -""" |
105 | | - trunctol(atol::Real; by=abs) |
106 | | -
|
107 | | -Truncation strategy to discard the values that are smaller than `atol` according to `by`. |
108 | | -""" |
109 | | -trunctol(atol; by=abs) = TruncationKeepFiltered(≥(atol) ∘ by) |
110 | | - |
111 | | -""" |
112 | | - truncabove(atol::Real; by=abs) |
113 | | -
|
114 | | -Truncation strategy to discard the values that are larger than `atol` according to `by`. |
115 | | -""" |
116 | | -truncabove(atol; by=abs) = TruncationKeepFiltered(≤(atol) ∘ by) |
117 | | - |
118 | | -""" |
119 | | - TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) |
120 | | -
|
121 | | -Composition of multiple truncation strategies, keeping values common between them. |
122 | | -""" |
123 | | -struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <: |
124 | | - TruncationStrategy |
125 | | - components::T |
126 | | -end |
127 | | -function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...) |
128 | | - return TruncationIntersection((trunc, truncs...)) |
129 | | -end |
130 | | - |
131 | | -function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy) |
132 | | - return TruncationIntersection((trunc1, trunc2)) |
133 | | -end |
134 | | -function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection) |
135 | | - return TruncationIntersection((trunc1.components..., trunc2.components...)) |
136 | | -end |
137 | | -function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy) |
138 | | - return TruncationIntersection((trunc1.components..., trunc2)) |
139 | | -end |
140 | | -function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection) |
141 | | - return TruncationIntersection((trunc1, trunc2.components...)) |
142 | | -end |
143 | | - |
144 | 1 | # truncate! |
145 | 2 | # --------- |
146 | 3 | # Generic implementation: `findtruncated` followed by indexing |
147 | | -@doc """ |
148 | | - truncate!(f, out, strategy::TruncationStrategy) |
149 | | -
|
150 | | -Generic interface for post-truncating a decomposition, specified in `out`. |
151 | | -""" truncate! |
152 | | -# TODO: should we return a view? |
153 | 4 | function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) |
154 | 5 | ind = findtruncated_sorted(diagview(S), strategy) |
155 | 6 | return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :] |
|
178 | 29 | # findtruncated |
179 | 30 | # ------------- |
180 | 31 | # specific implementations for finding truncated values |
181 | | -@doc """ |
182 | | - MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy) |
183 | | -
|
184 | | -Generic interface for finding truncated values of the spectrum of a decomposition |
185 | | -based on the `strategy`. The output should be a collection of indices specifying |
186 | | -which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default |
187 | | -implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the |
188 | | -values are sorted. For a version that assumes the values are reverse sorted (which is the |
189 | | -standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref). |
190 | | -""" findtruncated |
191 | | - |
192 | | -@doc """ |
193 | | - MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) |
194 | | -
|
195 | | -Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order. |
196 | | -They are assumed to be sorted in a way that is consistent with the truncation strategy, |
197 | | -which generally means they are sorted by absolute value but some truncation strategies allow |
198 | | -customizing that. However, note that this assumption is not checked, so passing values that are not sorted |
199 | | -in the correct way can silently give unexpected results. This is used in the default implementation of |
200 | | -[`svd_trunc!`](@ref). |
201 | | -""" findtruncated_sorted |
202 | | - |
203 | 32 | findtruncated(values::AbstractVector, ::NoTruncation) = Colon() |
204 | 33 |
|
205 | | -# TODO: this may also permute the eigenvalues, decide if we want to allow this or not |
206 | | -# can be solved by going to simply sorting the resulting `ind` |
207 | 34 | function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted) |
208 | 35 | howmany = min(strategy.howmany, length(values)) |
209 | 36 | return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev) |
@@ -243,19 +70,40 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection) |
243 | 70 | inds = map(Base.Fix1(findtruncated, values), strategy.components) |
244 | 71 | return intersect(inds...) |
245 | 72 | end |
| 73 | +function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection) |
| 74 | + inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components) |
| 75 | + return intersect(inds...) |
| 76 | +end |
246 | 77 |
|
247 | | -# Generic fallback. |
248 | | -function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) |
249 | | - return findtruncated(values, strategy) |
| 78 | +function findtruncated(values::AbstractVector, strategy::TruncationError) |
| 79 | + I = sortperm(values; by=abs, rev=true) |
| 80 | + I′ = _truncerr_impl(values, I, strategy) |
| 81 | + return I[I′] |
250 | 82 | end |
| 83 | +function findtruncated_sorted(values::AbstractVector, strategy::TruncationError) |
| 84 | + I = eachindex(values) |
| 85 | + I′ = _truncerr_impl(values, I, strategy) |
| 86 | + return I[I′] |
| 87 | +end |
| 88 | +function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError) |
| 89 | + Nᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values) |
| 90 | + ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ) |
| 91 | + ϵᵖ ≥ Nᵖ && return Base.OneTo(0) |
251 | 92 |
|
252 | | -""" |
253 | | - TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm) |
| 93 | + truncerrᵖ = zero(real(eltype(values))) |
| 94 | + rank = length(values) |
| 95 | + for i in reverse(I) |
| 96 | + truncerrᵖ += abs(values[i])^strategy.p |
| 97 | + if truncerrᵖ ≥ ϵᵖ |
| 98 | + break |
| 99 | + else |
| 100 | + rank -= 1 |
| 101 | + end |
| 102 | + end |
| 103 | + return Base.OneTo(rank) |
| 104 | +end |
254 | 105 |
|
255 | | -Generic wrapper type for algorithms that consist of first using `alg`, followed by a |
256 | | -truncation through `trunc`. |
257 | | -""" |
258 | | -struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm |
259 | | - alg::A |
260 | | - trunc::T |
| 106 | +# Generic fallback |
| 107 | +function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy) |
| 108 | + return findtruncated(values, strategy) |
261 | 109 | end |
0 commit comments