Skip to content

Commit 5adbbff

Browse files
committed
Truncation composition
1 parent 47cef6e commit 5adbbff

File tree

9 files changed

+63
-10
lines changed

9 files changed

+63
-10
lines changed

ext/MatrixAlgebraKitChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,4 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A::AbstractMatrix, PWᴴ,
176176
return PWᴴ, right_polar_pullback
177177
end
178178

179-
end
179+
end

src/algorithms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,4 @@ macro check_size(x, sz, size=:size)
178178
string($sz)
179179
szx == $sz || throw(DimensionMismatch($err))
180180
end)
181-
end
181+
end

src/implementations/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,4 @@ function right_null!(A::AbstractMatrix, Nᴴ; kwargs...)
215215
else
216216
throw(ArgumentError("`right_null!` received unknown value `kind = $kind`"))
217217
end
218-
end
218+
end

src/implementations/truncation.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@ function TruncationStrategy(; atol=nothing, rtol=nothing, maxrank=nothing)
1111
if isnothing(maxrank) && isnothing(atol) && isnothing(rtol)
1212
return NoTruncation()
1313
elseif isnothing(maxrank)
14-
@assert isnothing(rtol) "TODO: rtol"
15-
return trunctol(atol)
14+
atol = @something atol 0
15+
rtol = @something rtol 0
16+
return TruncationKeepAbove(atol, rtol)
1617
else
17-
return truncrank(maxrank)
18+
if isnothing(atol) && isnothing(rtol)
19+
return truncrank(maxrank)
20+
else
21+
atol = @something atol 0
22+
rtol = @something rtol 0
23+
return truncrank(maxrank) & TruncationKeepAbove(atol, rtol)
24+
end
1825
end
1926
end
2027

@@ -82,6 +89,20 @@ Truncation strategy to discard the values that are larger than `atol` in absolut
8289
"""
8390
truncabove(atol) = TruncationKeepFiltered((atol) abs)
8491

92+
"""
93+
TruncationComposition(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
94+
95+
Compose two truncation strategies, keeping values common between the two strategies.
96+
"""
97+
struct TruncationComposition{T1<:TruncationStrategy,T2<:TruncationStrategy} <:
98+
TruncationStrategy
99+
trunc1::T1
100+
trunc2::T2
101+
end
102+
function Base.:&(trunc1::TruncationStrategy, trunc2::TruncationStrategy)
103+
return TruncationComposition(trunc1, trunc2)
104+
end
105+
85106
# truncate!
86107
# ---------
87108
# Generic implementation: `findtruncated` followed by indexing
@@ -147,6 +168,12 @@ function findtruncated(values::AbstractVector, strategy::TruncationKeepAbove)
147168
return 1:i
148169
end
149170

171+
function findtruncated(values::AbstractVector, strategy::TruncationComposition)
172+
ind1 = findtruncated(values, strategy.trunc1)
173+
ind2 = findtruncated(values, strategy.trunc2)
174+
return ind1 ind2
175+
end
176+
150177
"""
151178
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)
152179

src/interface/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,4 @@ function right_null!(A::AbstractMatrix; kwargs...)
224224
end
225225
function right_null(A::AbstractMatrix; kwargs...)
226226
return right_null!(copy_input(right_null, A); kwargs...)
227-
end
227+
end

src/pullbacks/polar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@ function right_polar_pullback!(ΔA::AbstractMatrix, PWᴴ, ΔPWᴴ)
5858
ΔA .+= PΔWᴴ
5959
end
6060
return ΔA
61-
end
61+
end

test/chainrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,4 +356,4 @@ end
356356
test_rrule(config, right_null, A; fkwargs=(; kind=:lqpos), output_tangent=ΔNᴴ,
357357
atol=atol, rtol=rtol, rrule_f=rrule_via_ad, check_inferred=false)
358358
end
359-
end
359+
end

test/orthnull.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,4 @@ end
209209
end
210210
end
211211
end
212-
end
212+
end

test/svd.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,29 @@ end
108108
end
109109
end
110110
end
111+
112+
@testset "svd_trunc! mix maxrank and tol for T = $T" for T in (Float32, Float64, ComplexF32,
113+
ComplexF64)
114+
rng = StableRNG(123)
115+
if LinearAlgebra.LAPACK.version() < v"3.12.0"
116+
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection())
117+
else
118+
algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(),
119+
LAPACK_Jacobi())
120+
end
121+
m = 4
122+
@testset "algorithm $alg" for alg in algs
123+
U = qr_compact(randn(rng, T, m, m))[1]
124+
S = Diagonal([0.9, 0.3, 0.1, 0.01])
125+
Vᴴ = qr_compact(randn(rng, T, m, m))[1]
126+
A = U * S * Vᴴ
127+
128+
U1, S1, V1ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=1))
129+
@test length(S1.diag) == 1
130+
@test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T)))
131+
132+
U2, S2, V2ᴴ = svd_trunc(A; alg, trunc=(; rtol=0.2, maxrank=3))
133+
@test length(S2.diag) == 2
134+
@test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T)))
135+
end
136+
end

0 commit comments

Comments
 (0)