Skip to content

Commit 2d6443e

Browse files
authored
Refactor truncation module + add truncerror (#55)
* Refactor truncation implementation * relax norm types * Add TruncationError * update docs * Add some tests * move definition of `&` * Add `notrunc` docstrings * update `findtruncated_sorted` docstring
1 parent 8ebae19 commit 2d6443e

File tree

9 files changed

+317
-188
lines changed

9 files changed

+317
-188
lines changed

docs/src/user_interface/truncations.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,18 @@ CollapsedDocStrings = true
88
Currently, truncations are supported through the following different methods:
99

1010
```@docs; canonical=false
11+
notrunc
1112
truncrank
1213
trunctol
1314
truncabove
15+
truncerror
16+
```
17+
18+
It is additionally possible to combine truncation strategies by making use of the `&` operator.
19+
For example, truncating to a maximal dimension `10`, and discarding all values below `1e-6` would be achieved by:
20+
21+
```julia
22+
maxdim = 10
23+
tol = 1e-6
24+
combined_trunc = truncrank(maxdim) & trunctol(tol)
1425
```

src/MatrixAlgebraKit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ,
3737
CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer,
3838
ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, ROCSOLVER_DivideAndConquer, ROCSOLVER_Bisection,
3939
DiagonalAlgorithm
40-
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered
40+
export truncrank, trunctol, truncabove, TruncationKeepSorted, TruncationKeepFiltered, truncerror
4141

4242
VERSION >= v"1.11.0-DEV.469" &&
4343
eval(Expr(:public, :default_algorithm, :findtruncated, :findtruncated_sorted,
@@ -55,6 +55,7 @@ include("common/gauge.jl")
5555
include("yalapack.jl")
5656
include("algorithms.jl")
5757
include("interface/decompositions.jl")
58+
include("interface/truncation.jl")
5859
include("interface/qr.jl")
5960
include("interface/lq.jl")
6061
include("interface/svd.jl")

src/algorithms.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,71 @@ If this is not possible, for example when the output size is not known a priori
131131
this function may return `nothing`.
132132
""" initialize_output
133133

134+
# Truncation strategy
135+
# -------------------
136+
"""
137+
abstract type TruncationStrategy end
138+
139+
Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation.
140+
141+
See also [`truncate!`](@ref)
142+
"""
143+
abstract type TruncationStrategy end
144+
145+
@doc """
146+
MatrixAlgebraKit.select_truncation(trunc)
147+
148+
Construct a [`TruncationStrategy`](@ref) from the given `NamedTuple` of keywords or input strategy.
149+
""" select_truncation
150+
151+
function select_truncation(trunc)
152+
if isnothing(trunc)
153+
return NoTruncation()
154+
elseif trunc isa NamedTuple
155+
return TruncationStrategy(; trunc...)
156+
elseif trunc isa TruncationStrategy
157+
return trunc
158+
else
159+
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
160+
end
161+
end
162+
163+
@doc """
164+
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)
165+
166+
Generic interface for finding truncated values of the spectrum of a decomposition
167+
based on the `strategy`. The output should be a collection of indices specifying
168+
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
169+
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
170+
values are sorted. For a version that assumes the values are reverse sorted (which is the
171+
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
172+
""" findtruncated
173+
174+
@doc """
175+
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
176+
177+
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are real and
178+
sorted in descending order, as typically obtained by the SVD. This assumption is not
179+
checked, and this is used in the default implementation of [`svd_trunc!`](@ref).
180+
""" findtruncated_sorted
181+
182+
"""
183+
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
184+
185+
Generic wrapper type for algorithms that consist of first using `alg`, followed by a
186+
truncation through `trunc`.
187+
"""
188+
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
189+
alg::A
190+
trunc::T
191+
end
192+
193+
@doc """
194+
truncate!(f, out, strategy::TruncationStrategy)
195+
196+
Generic interface for post-truncating a decomposition, specified in `out`.
197+
""" truncate!
198+
134199
# Utility macros
135200
# --------------
136201

src/implementations/truncation.jl

Lines changed: 32 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -1,155 +1,6 @@
1-
"""
2-
abstract type TruncationStrategy end
3-
4-
Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation.
5-
6-
See also [`truncate!`](@ref)
7-
"""
8-
abstract type TruncationStrategy end
9-
10-
function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
11-
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
12-
return NoTruncation()
13-
elseif isnothing(maxrank)
14-
atol = @something atol 0
15-
rtol = @something rtol 0
16-
return TruncationKeepAbove(atol, rtol)
17-
else
18-
if isnothing(atol) && isnothing(rtol)
19-
return truncrank(maxrank)
20-
else
21-
atol = @something atol 0
22-
rtol = @something rtol 0
23-
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
24-
end
25-
end
26-
end
27-
28-
"""
29-
NoTruncation()
30-
31-
Trivial truncation strategy that keeps all values, mostly for testing purposes.
32-
"""
33-
struct NoTruncation <: TruncationStrategy end
34-
35-
function select_truncation(trunc)
36-
if isnothing(trunc)
37-
return NoTruncation()
38-
elseif trunc isa NamedTuple
39-
return TruncationStrategy(; trunc...)
40-
elseif trunc isa TruncationStrategy
41-
return trunc
42-
else
43-
return throw(ArgumentError("Unknown truncation strategy: $trunc"))
44-
end
45-
end
46-
47-
# TODO: how do we deal with sorting/filters that treat zeros differently
48-
# since these are implicitly discarded by selecting compact/full
49-
50-
"""
51-
TruncationKeepSorted(howmany::Int, by::Function, rev::Bool)
52-
53-
Truncation strategy to keep the first `howmany` values when sorted according to `by` in increasing (decreasing) order if `rev` is false (true).
54-
"""
55-
struct TruncationKeepSorted{F} <: TruncationStrategy
56-
howmany::Int
57-
by::F
58-
rev::Bool
59-
end
60-
61-
"""
62-
TruncationKeepFiltered(filter::Function)
63-
64-
Truncation strategy to keep the values for which `filter` returns true.
65-
"""
66-
struct TruncationKeepFiltered{F} <: TruncationStrategy
67-
filter::F
68-
end
69-
70-
struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy
71-
atol::T
72-
rtol::T
73-
p::Int
74-
by::F
75-
end
76-
function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs)
77-
return TruncationKeepAbove(atol, rtol, p, by)
78-
end
79-
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs)
80-
return TruncationKeepAbove(promote(atol, rtol)..., p, by)
81-
end
82-
83-
struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy
84-
atol::T
85-
rtol::T
86-
p::Int
87-
by::F
88-
end
89-
function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs)
90-
return TruncationKeepBelow(atol, rtol, p, by)
91-
end
92-
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs)
93-
return TruncationKeepBelow(promote(atol, rtol)..., p, by)
94-
end
95-
96-
# TODO: better names for these functions of the above types
97-
"""
98-
truncrank(howmany::Int; by=abs, rev=true)
99-
100-
Truncation strategy to keep the first `howmany` values when sorted according to `by` or the last `howmany` if `rev` is true.
101-
"""
102-
truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev)
103-
104-
"""
105-
trunctol(atol::Real; by=abs)
106-
107-
Truncation strategy to discard the values that are smaller than `atol` according to `by`.
108-
"""
109-
trunctol(atol; by=abs) = TruncationKeepFiltered((atol) by)
110-
111-
"""
112-
truncabove(atol::Real; by=abs)
113-
114-
Truncation strategy to discard the values that are larger than `atol` according to `by`.
115-
"""
116-
truncabove(atol; by=abs) = TruncationKeepFiltered((atol) by)
117-
118-
"""
119-
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
120-
121-
Composition of multiple truncation strategies, keeping values common between them.
122-
"""
123-
struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <:
124-
TruncationStrategy
125-
components::T
126-
end
127-
function TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
128-
return TruncationIntersection((trunc, truncs...))
129-
end
130-
131-
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
132-
return TruncationIntersection((trunc1, trunc2))
133-
end
134-
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection)
135-
return TruncationIntersection((trunc1.components..., trunc2.components...))
136-
end
137-
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy)
138-
return TruncationIntersection((trunc1.components..., trunc2))
139-
end
140-
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
141-
return TruncationIntersection((trunc1, trunc2.components...))
142-
end
143-
1441
# truncate!
1452
# ---------
1463
# Generic implementation: `findtruncated` followed by indexing
147-
@doc """
148-
truncate!(f, out, strategy::TruncationStrategy)
149-
150-
Generic interface for post-truncating a decomposition, specified in `out`.
151-
""" truncate!
152-
# TODO: should we return a view?
1534
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
1545
ind = findtruncated_sorted(diagview(S), strategy)
1556
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
@@ -178,32 +29,8 @@ end
17829
# findtruncated
17930
# -------------
18031
# specific implementations for finding truncated values
181-
@doc """
182-
MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy)
183-
184-
Generic interface for finding truncated values of the spectrum of a decomposition
185-
based on the `strategy`. The output should be a collection of indices specifying
186-
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
187-
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
188-
values are sorted. For a version that assumes the values are reverse sorted (which is the
189-
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
190-
""" findtruncated
191-
192-
@doc """
193-
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
194-
195-
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
196-
They are assumed to be sorted in a way that is consistent with the truncation strategy,
197-
which generally means they are sorted by absolute value but some truncation strategies allow
198-
customizing that. However, note that this assumption is not checked, so passing values that are not sorted
199-
in the correct way can silently give unexpected results. This is used in the default implementation of
200-
[`svd_trunc!`](@ref).
201-
""" findtruncated_sorted
202-
20332
findtruncated(values::AbstractVector, ::NoTruncation) = Colon()
20433

205-
# TODO: this may also permute the eigenvalues, decide if we want to allow this or not
206-
# can be solved by going to simply sorting the resulting `ind`
20734
function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted)
20835
howmany = min(strategy.howmany, length(values))
20936
return partialsortperm(values, 1:howmany; by=strategy.by, rev=strategy.rev)
@@ -243,19 +70,40 @@ function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
24370
inds = map(Base.Fix1(findtruncated, values), strategy.components)
24471
return intersect(inds...)
24572
end
73+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationIntersection)
74+
inds = map(Base.Fix1(findtruncated_sorted, values), strategy.components)
75+
return intersect(inds...)
76+
end
24677

247-
# Generic fallback.
248-
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
249-
return findtruncated(values, strategy)
78+
function findtruncated(values::AbstractVector, strategy::TruncationError)
79+
I = sortperm(values; by=abs, rev=true)
80+
I′ = _truncerr_impl(values, I, strategy)
81+
return I[I′]
25082
end
83+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationError)
84+
I = eachindex(values)
85+
I′ = _truncerr_impl(values, I, strategy)
86+
return I[I′]
87+
end
88+
function _truncerr_impl(values::AbstractVector, I, strategy::TruncationError)
89+
Nᵖ = sum(Base.Fix2(^, strategy.p) abs, values)
90+
ϵᵖ = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * Nᵖ)
91+
ϵᵖ Nᵖ && return Base.OneTo(0)
25192

252-
"""
253-
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
93+
truncerrᵖ = zero(real(eltype(values)))
94+
rank = length(values)
95+
for i in reverse(I)
96+
truncerrᵖ += abs(values[i])^strategy.p
97+
if truncerrᵖ ϵᵖ
98+
break
99+
else
100+
rank -= 1
101+
end
102+
end
103+
return Base.OneTo(rank)
104+
end
254105

255-
Generic wrapper type for algorithms that consist of first using `alg`, followed by a
256-
truncation through `trunc`.
257-
"""
258-
struct TruncatedAlgorithm{A,T} <: AbstractAlgorithm
259-
alg::A
260-
trunc::T
106+
# Generic fallback
107+
function findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
108+
return findtruncated(values, strategy)
261109
end

0 commit comments

Comments
 (0)