Skip to content

Commit e9df0b2

Browse files
authored
Truncation composition (#18)
* Truncation composition * Add tests for truncation objects * Fix truncated SVD test logic * Change name to TruncationIntersection
1 parent ec40b48 commit e9df0b2

File tree

5 files changed

+101
-5
lines changed

5 files changed

+101
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MatrixAlgebraKit"
22
uuid = "6c742aac-3347-4629-af66-fc926824e5e4"
33
authors = ["Jutho <[email protected]> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/implementations/truncation.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
1111
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
1212
return NoTruncation()
1313
elseif isnothing(maxrank)
14-
@assert isnothing(rtol) "TODO: rtol"
15-
return trunctol(atol)
14+
atol = @something atol 0
15+
rtol = @something rtol 0
16+
return TruncationKeepAbove(atol, rtol)
1617
else
17-
return truncrank(maxrank)
18+
if isnothing(atol) && isnothing(rtol)
19+
return truncrank(maxrank)
20+
else
21+
atol = @something atol 0
22+
rtol = @something rtol 0
23+
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
24+
end
1825
end
1926
end
2027

@@ -82,6 +89,28 @@ Truncation strategy to discard the values that are larger than `atol` in absolut
8289
"""
8390
truncabove(atol) = TruncationKeepFiltered((atol) abs)
8491

92+
"""
93+
TruncationIntersection(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
94+
95+
Compose two truncation strategies, keeping values common between the two strategies.
96+
"""
97+
struct TruncationIntersection{T<:Tuple{Vararg{TruncationStrategy}}} <:
98+
TruncationStrategy
99+
components::T
100+
end
101+
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
102+
return TruncationIntersection((trunc1, trunc2))
103+
end
104+
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationIntersection)
105+
return TruncationIntersection((trunc1.components..., trunc2.components...))
106+
end
107+
function Base.:&(trunc1::TruncationIntersection, trunc2::TruncationStrategy)
108+
return TruncationIntersection((trunc1.components..., trunc2))
109+
end
110+
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationIntersection)
111+
return TruncationIntersection((trunc1, trunc2.components...))
112+
end
113+
85114
# truncate!
86115
# ---------
87116
# Generic implementation: `findtruncated` followed by indexing
@@ -147,6 +176,11 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
147176
return 1:i
148177
end
149178

179+
function findtruncated(values::AbstractVector, strategy::TruncationIntersection)
180+
inds = map(Base.Fix1(findtruncated, values), strategy.components)
181+
return intersect(inds...)
182+
end
183+
150184
"""
151185
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
152186

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using SafeTestsets
22

3+
@safetestset "Truncate" begin
4+
include("truncate.jl")
5+
end
36
@safetestset "QR / LQ Decomposition" begin
47
include("qr.jl")
58
include("lq.jl")

test/svd.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef
6-
using MatrixAlgebraKit: diagview
6+
using MatrixAlgebraKit: TruncationKeepAbove, diagview
77

88
@testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
99
rng = StableRNG(123)
@@ -115,3 +115,33 @@ end
115115
end
116116
end
117117
end
118+
119+
@testset "svd_trunc! mix maxrank and tol for T = $T" for T in
120+
(Float32, Float64, ComplexF32,
121+
ComplexF64)
122+
rng = StableRNG(123)
123+
if LinearAlgebra.LAPACK.version() < v"3.12.0"
124+
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
125+
else
126+
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
127+
LAPACK_Jacobi())
128+
end
129+
m = 4
130+
@testset "algorithm $alg" for alg in algs
131+
U = qr_compact(randn(rng, T, m, m))[1]
132+
S = Diagonal([0.9, 0.3, 0.1, 0.01])
133+
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
134+
A = U * S * Vᴴ
135+
136+
for trunc_fun in ((rtol, maxrank) -> (; rtol, maxrank),
137+
(rtol, maxrank) -> truncrank(maxrank) & TruncationKeepAbove(0, rtol))
138+
U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 1))
139+
@test length(S1.diag) == 1
140+
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))
141+
142+
U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=trunc_fun(0.2, 3))
143+
@test length(S2.diag) == 2
144+
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
145+
end
146+
end
147+
end

test/truncate.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using MatrixAlgebraKit: NoTruncation, TruncationIntersection, TruncationKeepAbove,
5+
TruncationStrategy
6+
7+
@testset "truncate" begin
8+
trunc = @constinferred TruncationStrategy()
9+
@test trunc isa NoTruncation
10+
11+
trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3)
12+
@test trunc isa TruncationKeepAbove
13+
@test trunc == TruncationKeepAbove(1e-2, 1e-3)
14+
@test trunc.atol == 1e-2
15+
@test trunc.rtol == 1e-3
16+
17+
trunc = @constinferred TruncationStrategy(; maxrank=10)
18+
@test trunc isa TruncationKeepSorted
19+
@test trunc == truncrank(10)
20+
@test trunc.howmany == 10
21+
@test trunc.sortby == abs
22+
@test trunc.rev == true
23+
24+
trunc = @constinferred TruncationStrategy(; atol=1e-2, rtol=1e-3, maxrank=10)
25+
@test trunc isa TruncationIntersection
26+
@test trunc == truncrank(10) & TruncationKeepAbove(1e-2, 1e-3)
27+
@test trunc.components[1] == truncrank(10)
28+
@test trunc.components[2] == TruncationKeepAbove(1e-2, 1e-3)
29+
end

0 commit comments

Comments
 (0)