Skip to content

Commit 2acaccd

Browse files
authored
Properly use abs/by in TruncationKeepAbove/Below (#33)
1 parent 765a6a4 commit 2acaccd

File tree

4 files changed

+71
-33
lines changed

4 files changed

+71
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <[email protected]> and contributors"]
4-
version = "0.2.1"
4+
version = "0.2.2"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/algorithms.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F
9999
throw(ArgumentError("Unknown alg $alg"))
100100
end
101101

102-
103102
@doc """
104103
MatrixAlgebraKit.default_algorithm(f, A; kwargs...)
105104
MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA}

src/implementations/truncation.jl

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,30 @@ struct TruncationKeepFiltered{F} <: TruncationStrategy
6767
filter::F
6868
end
6969

70-
struct TruncationKeepAbove{T<:Real} <: TruncationStrategy
70+
struct TruncationKeepAbove{T<:Real,F} <: TruncationStrategy
7171
atol::T
7272
rtol::T
7373
p::Int
74+
by::F
75+
end
76+
function TruncationKeepAbove(; atol::Real, rtol::Real, p::Int=2, by=abs)
77+
return TruncationKeepAbove(atol, rtol, p, by)
7478
end
75-
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2)
76-
return TruncationKeepAbove(promote(atol, rtol)..., p)
79+
function TruncationKeepAbove(atol::Real, rtol::Real, p::Int=2, by=abs)
80+
return TruncationKeepAbove(promote(atol, rtol)..., p, by)
7781
end
7882

79-
struct TruncationKeepBelow{T<:Real} <: TruncationStrategy
83+
struct TruncationKeepBelow{T<:Real,F} <: TruncationStrategy
8084
atol::T
8185
rtol::T
8286
p::Int
87+
by::F
88+
end
89+
function TruncationKeepBelow(; atol::Real, rtol::Real, p::Int=2, by=abs)
90+
return TruncationKeepBelow(atol, rtol, p, by)
8391
end
84-
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2)
85-
return TruncationKeepBelow(promote(atol, rtol)..., p)
92+
function TruncationKeepBelow(atol::Real, rtol::Real, p::Int=2, by=abs)
93+
return TruncationKeepBelow(promote(atol, rtol)..., p, by)
8694
end
8795

8896
# TODO: better names for these functions of the above types
@@ -94,18 +102,18 @@ Truncation strategy to keep the first `howmany` values when sorted according to
94102
truncrank(howmany::Int; by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev)
95103

96104
"""
97-
trunctol(atol::Real)
105+
trunctol(atol::Real; by=abs)
98106
99-
Truncation strategy to discard the values that are smaller than `atol` in absolute value.
107+
Truncation strategy to discard the values that are smaller than `atol` according to `by`.
100108
"""
101-
trunctol(atol) = TruncationKeepFiltered((atol) abs)
109+
trunctol(atol; by=abs) = TruncationKeepFiltered((atol) by)
102110

103111
"""
104-
truncabove(atol::Real)
112+
truncabove(atol::Real; by=abs)
105113
106-
Truncation strategy to discard the values that are larger than `atol` in absolute value.
114+
Truncation strategy to discard the values that are larger than `atol` according to `by`.
107115
"""
108-
truncabove(atol) = TruncationKeepFiltered((atol) abs)
116+
truncabove(atol; by=abs) = TruncationKeepFiltered((atol) by)
109117

110118
"""
111119
TruncationIntersection(trunc::TruncationStrategy, truncs::TruncationStrategy...)
@@ -177,17 +185,18 @@ Generic interface for finding truncated values of the spectrum of a decompositio
177185
based on the `strategy`. The output should be a collection of indices specifying
178186
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
179187
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
180-
values are sorted. For a version that assumes the values are reverse sorted by
181-
absolute value (which is the standard case for SVD) see
182-
[`MatrixAlgebraKit.findtruncated_sorted`](@ref).
188+
values are sorted. For a version that assumes the values are reverse sorted (which is the
189+
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_sorted`](@ref).
183190
""" findtruncated
184191

185192
@doc """
186193
MatrixAlgebraKit.findtruncated_sorted(values::AbstractVector, strategy::TruncationStrategy)
187194
188-
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order by
189-
absolute value. However, note that this assumption is not checked, so passing values that are not sorted
190-
in that way can silently give unexpected results. This is used in the default implementation of
195+
Like [`MatrixAlgebraKit.findtruncated`](@ref) but assumes that the values are sorted in reverse order.
196+
They are assumed to be sorted in a way that is consistent with the truncation strategy,
197+
which generally means they are sorted by absolute value but some truncation strategies allow
198+
customizing that. However, note that this assumption is not checked, so passing values that are not sorted
199+
in the correct way can silently give unexpected results. This is used in the default implementation of
191200
[`svd_trunc!`](@ref).
192201
""" findtruncated_sorted
193202

@@ -212,21 +221,21 @@ end
212221

213222
function findtruncated(values::AbstractVector, strategy::TruncationKeepBelow)
214223
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
215-
return findall((atol), values)
224+
return findall((atol) strategy.by, values)
216225
end
217226
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepBelow)
218227
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
219-
i = searchsortedfirst(values, atol; by=abs, rev=true)
228+
i = searchsortedfirst(values, atol; by=strategy.by, rev=true)
220229
return i:length(values)
221230
end
222231

223232
function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
224233
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
225-
return findall((atol), values)
234+
return findall((atol) strategy.by, values)
226235
end
227236
function findtruncated_sorted(values::AbstractVector, strategy::TruncationKeepAbove)
228237
atol = max(strategy.atol, strategy.rtol * norm(values, strategy.p))
229-
i = searchsortedlast(values, atol; by=abs, rev=true)
238+
i = searchsortedlast(values, atol; by=strategy.by, rev=true)
230239
return 1:i
231240
end
232241

test/truncate.jl

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using MatrixAlgebraKit
22
using Test
33
using TestExtras
44
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
5-
TruncationKeepBelow, TruncationStrategy, findtruncated
5+
TruncationKeepBelow, TruncationStrategy, findtruncated,
6+
findtruncated_sorted
67

78
@testset "truncate" begin
89
trunc = @constinferred TruncationStrategy()
@@ -27,16 +28,45 @@ using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbov
2728
@test trunc.components[1] == truncrank(10)
2829
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)
2930

30-
values = [1, 0.9, 0.5, 0.3, 0.01]
31+
values = [1, 0.9, 0.5, -0.3, 0.01]
3132
@test @constinferred(findtruncated(values, truncrank(2))) == 1:2
3233
@test @constinferred(findtruncated(values, truncrank(2; rev=false))) == [5, 4]
33-
@test @constinferred(findtruncated(values, truncrank(2; by=-))) == [5, 4]
34+
@test @constinferred(findtruncated(values, truncrank(2; by=((-) abs)))) == [5, 4]
35+
@test @constinferred(findtruncated_sorted(values, truncrank(2))) === 1:2
3436

35-
values = [1, 0.9, 0.5, 0.3, 0.01]
36-
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == 1:3
37-
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == 4:5
37+
values = [1, 0.9, 0.5, -0.3, 0.01]
38+
for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0),
39+
TruncationKeepAbove(0.4, 0))
40+
@test @constinferred(findtruncated(values, strategy)) == 1:3
41+
@test @constinferred(findtruncated_sorted(values, strategy)) === 1:3
42+
end
43+
for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0),
44+
TruncationKeepBelow(0.4, 0))
45+
@test @constinferred(findtruncated(values, strategy)) == 4:5
46+
@test @constinferred(findtruncated_sorted(values, strategy)) === 4:5
47+
end
3848

39-
values = [0.01, 1, 0.9, 0.3, 0.5]
40-
@test @constinferred(findtruncated(values, TruncationKeepAbove(0.4, 0.0))) == [2, 3, 5]
41-
@test @constinferred(findtruncated(values, TruncationKeepBelow(0.4, 0.0))) == [1, 4]
49+
values = [0.01, 1, 0.9, -0.3, 0.5]
50+
for strategy in (TruncationKeepAbove(; atol=0.4, rtol=0),
51+
TruncationKeepAbove(; atol=0.4, rtol=0, by=abs),
52+
TruncationKeepAbove(0.4, 0),
53+
TruncationKeepAbove(; atol=0.2, rtol=0.0, by=identity))
54+
@test @constinferred(findtruncated(values, strategy)) == [2, 3, 5]
55+
end
56+
for strategy in (TruncationKeepAbove(; atol=0.2, rtol=0),
57+
TruncationKeepAbove(; atol=0.2, rtol=0, by=abs),
58+
TruncationKeepAbove(0.2, 0))
59+
@test @constinferred(findtruncated(values, strategy)) == [2, 3, 4, 5]
60+
end
61+
for strategy in (TruncationKeepBelow(; atol=0.4, rtol=0),
62+
TruncationKeepBelow(; atol=0.4, rtol=0, by=abs),
63+
TruncationKeepBelow(0.4, 0),
64+
TruncationKeepBelow(; atol=0.2, rtol=0.0, by=identity))
65+
@test @constinferred(findtruncated(values, strategy)) == [1, 4]
66+
end
67+
for strategy in (TruncationKeepBelow(; atol=0.2, rtol=0),
68+
TruncationKeepBelow(; atol=0.2, rtol=0, by=abs),
69+
TruncationKeepBelow(0.2, 0))
70+
@test @constinferred(findtruncated(values, strategy)) == [1]
71+
end
4272
end

0 commit comments

Comments
 (0)