@@ -67,22 +67,30 @@ struct TruncationKeepFiltered{F} <: TruncationStrategy
6767 filter:: F
6868end
6969
70- struct TruncationKeepAbove{T<: Real } <: TruncationStrategy
70+ struct TruncationKeepAbove{T<: Real ,F } <: TruncationStrategy
7171 atol:: T
7272 rtol:: T
7373 p:: Int
74+ by:: F
75+ end
76+ function TruncationKeepAbove(; atol:: Real , rtol:: Real , p:: Int = 2 , by= abs)
77+ return TruncationKeepAbove(atol, rtol, p, by)
7478end
75- function TruncationKeepAbove(atol:: Real , rtol:: Real , p:: Int = 2 )
76- return TruncationKeepAbove(promote(atol, rtol). .. , p)
79+ function TruncationKeepAbove(atol:: Real , rtol:: Real , p:: Int = 2 , by = abs )
80+ return TruncationKeepAbove(promote(atol, rtol). .. , p, by )
7781end
7882
79- struct TruncationKeepBelow{T<: Real } <: TruncationStrategy
83+ struct TruncationKeepBelow{T<: Real ,F } <: TruncationStrategy
8084 atol:: T
8185 rtol:: T
8286 p:: Int
87+ by:: F
88+ end
89+ function TruncationKeepBelow(; atol:: Real , rtol:: Real , p:: Int = 2 , by= abs)
90+ return TruncationKeepBelow(atol, rtol, p, by)
8391end
84- function TruncationKeepBelow(atol:: Real , rtol:: Real , p:: Int = 2 )
85- return TruncationKeepBelow(promote(atol, rtol). .. , p)
92+ function TruncationKeepBelow(atol:: Real , rtol:: Real , p:: Int = 2 , by = abs )
93+ return TruncationKeepBelow(promote(atol, rtol). .. , p, by )
8694end
8795
8896# TODO : better names for these functions of the above types
@@ -94,18 +102,18 @@ Truncation strategy to keep the first `howmany` values when sorted according to
94102truncrank(howmany:: Int ; by= abs, rev= true ) = TruncationKeepSorted(howmany, by, rev)
95103
96104"""
97- trunctol(atol::Real)
105+ trunctol(atol::Real; by=abs )
98106
99- Truncation strategy to discard the values that are smaller than `atol` in absolute value .
107+ Truncation strategy to discard the values that are smaller than `atol` according to `by` .
100108"""
101- trunctol(atol) = TruncationKeepFiltered(≥ (atol) ∘ abs )
109+ trunctol(atol; by = abs ) = TruncationKeepFiltered(≥ (atol) ∘ by )
102110
103111"""
104- truncabove(atol::Real)
112+ truncabove(atol::Real; by=abs )
105113
106- Truncation strategy to discard the values that are larger than `atol` in absolute value .
114+ Truncation strategy to discard the values that are larger than `atol` according to `by` .
107115"""
108- truncabove(atol) = TruncationKeepFiltered(≤ (atol) ∘ abs )
116+ truncabove(atol; by = abs ) = TruncationKeepFiltered(≤ (atol) ∘ by )
109117
110118"""
111119 TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
@@ -177,17 +185,18 @@ Generic interface for finding truncated values of the spectrum of a decompositio
177185based on the `strategy`. The output should be a collection of indices specifying
178186which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
179187implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
180- values are sorted. For a version that assumes the values are reverse sorted by
181- absolute value (which is the standard case for SVD) see
182- [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
188+ values are sorted. For a version that assumes the values are reverse sorted (which is the
189+ standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
183190""" findtruncated
184191
185192@doc """
186193 MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
187194
188- Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
189- absolute value. However, note that this assumption is not checked, so passing values that are not sorted
190- in that way can silently give unexpected results. This is used in the default implementation of
195+ Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
196+ They are assumed to be sorted in a way that is consistent with the truncation strategy,
197+ which generally means they are sorted by absolute value but some truncation strategies allow
198+ customizing that. However, note that this assumption is not checked, so passing values that are not sorted
199+ in the correct way can silently give unexpected results. This is used in the default implementation of
191200[`svd_trunc!`](@ref).
192201""" findtruncated_sorted
193202
@@ -212,21 +221,21 @@ end
212221
213222function findtruncated(values:: AbstractVector , strategy:: TruncationKeepBelow )
214223 atol = max(strategy. atol, strategy. rtol * norm(values, strategy. p))
215- return findall(≤ (atol), values)
224+ return findall(≤ (atol) ∘ strategy . by , values)
216225end
217226function findtruncated_sorted(values:: AbstractVector , strategy:: TruncationKeepBelow )
218227 atol = max(strategy. atol, strategy. rtol * norm(values, strategy. p))
219- i = searchsortedfirst(values, atol; by= abs , rev= true )
228+ i = searchsortedfirst(values, atol; by= strategy . by , rev= true )
220229 return i: length(values)
221230end
222231
223232function findtruncated(values:: AbstractVector , strategy:: TruncationKeepAbove )
224233 atol = max(strategy. atol, strategy. rtol * norm(values, strategy. p))
225- return findall(≥ (atol), values)
234+ return findall(≥ (atol) ∘ strategy . by , values)
226235end
227236function findtruncated_sorted(values:: AbstractVector , strategy:: TruncationKeepAbove )
228237 atol = max(strategy. atol, strategy. rtol * norm(values, strategy. p))
229- i = searchsortedlast(values, atol; by= abs , rev= true )
238+ i = searchsortedlast(values, atol; by= strategy . by , rev= true )
230239 return 1 : i
231240end
232241
0 commit comments