Skip to content

Commit 19cce9d

Browse files
committed
Fix and refactor truncation tests
1 parent 3a274b7 commit 19cce9d

File tree

3 files changed

+57
-49
lines changed

3 files changed

+57
-49
lines changed

src/tensors/factorizations/truncation.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ truncspace(space::ElementarySpace) = TruncationSpace(space)
3636
function truncate!(::typeof(svd_trunc!),
3737
(U, S, Vᴴ)::Tuple{AbstractTensorMap,AbstractTensorMap,AbstractTensorMap},
3838
strategy::TruncationStrategy)
39+
strategy == notrunc() && return (U, S, Vᴴ)
3940
ind = findtruncated_sorted(diagview(S), strategy)
4041
V_truncated = spacetype(S)(c => length(I) for (c, I) in ind)
4142

@@ -66,6 +67,7 @@ end
6667
function truncate!(::typeof(left_null!),
6768
(U, S)::Tuple{AbstractTensorMap,AbstractTensorMap},
6869
strategy::MatrixAlgebraKit.TruncationStrategy)
70+
strategy == notrunc() && return (U, S)
6971
extended_S = SectorDict(c => vcat(diagview(b),
7072
zeros(eltype(b), max(0, size(b, 2) - size(b, 1))))
7173
for (c, b) in blocks(S))
@@ -82,6 +84,7 @@ for f! in (:eig_trunc!, :eigh_trunc!)
8284
@eval function truncate!(::typeof($f!),
8385
(D, V)::Tuple{AbstractTensorMap,AbstractTensorMap},
8486
strategy::TruncationStrategy)
87+
strategy == notrunc() && return (D, V)
8588
ind = findtruncated(diagview(D), strategy)
8689
V_truncated = spacetype(D)(c => length(I) for (c, I) in ind)
8790

@@ -136,10 +139,14 @@ function _findnexttruncvalue(S, truncdim::SectorDict{I,Int}; by=identity,
136139
end
137140

138141
# implementations
142+
function findtruncated_sorted(S::SectorDict, strategy::TruncationStrategy)
143+
return findtruncated(S, strategy)
144+
end
145+
139146
function findtruncated_sorted(S::SectorDict, strategy::TruncationKeepAbove)
140147
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
141148
findtrunc = Base.Fix2(findtruncated_sorted, truncbelow(atol))
142-
return SectorDict(c => findtrunc(d) for (c, d) in Sd)
149+
return SectorDict(c => findtrunc(d) for (c, d) in S)
143150
end
144151
function findtruncated(S::SectorDict, strategy::TruncationKeepAbove)
145152
atol = rtol_to_atol(S, strategy.p, strategy.atol, strategy.rtol)
@@ -195,13 +202,10 @@ function findtruncated(Sd::SectorDict, strategy::TruncationKeepSorted)
195202
_, cmin = next
196203
truncdim[cmin] -= 1
197204
totaldim -= dim(cmin)
198-
if totaldim < strategy.howmany
199-
# truncdim[cmin] += 1
200-
break
201-
end
202205
if truncdim[cmin] == 0
203206
delete!(truncdim, cmin)
204207
end
208+
totaldim <= strategy.howmany && break
205209
end
206210
return SectorDict(c => permutations[c][Base.OneTo(d)] for (c, d) in truncdim)
207211
end

test/factorizations.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,54 @@ for V in spacelist
204204
end
205205
end
206206

207+
@testset "truncated SVD" begin
208+
for T in eltypes,
209+
t in (randn(T, W, W), randn(T, W, W)',
210+
randn(T, W, V1), randn(T, V1, W),
211+
randn(T, W, V1)', randn(T, V1, W)',
212+
DiagonalTensorMap(randn(T, reduceddim(V1)), V1))
213+
214+
@constinferred normalize!(t)
215+
216+
U, S, Vᴴ = @constinferred svd_trunc(t; trunc=notrunc())
217+
@test U * S * Vᴴ t
218+
@test isisometry(U)
219+
@test isisometry(Vᴴ; side=:right)
220+
221+
trunc = truncrank(dim(domain(S)) ÷ 2)
222+
U1, S1, Vᴴ1 = @constinferred svd_trunc(t; trunc)
223+
@test t * Vᴴ1' U1 * S1
224+
@test isisometry(U1)
225+
@test isisometry(Vᴴ1; side=:right)
226+
@test dim(domain(S1)) <= trunc.howmany
227+
228+
λ = minimum(minimum, values(LinearAlgebra.diag(S1)))
229+
trunc = trunctol- 10eps(λ))
230+
U2, S2, Vᴴ2 = @constinferred svd_trunc(t; trunc)
231+
@test t * Vᴴ2' U2 * S2
232+
@test isisometry(U2)
233+
@test isisometry(Vᴴ2; side=:right)
234+
@test minimum(minimum, values(LinearAlgebra.diag(S1))) >= λ
235+
@test U2 U1
236+
@test S2 S1
237+
@test Vᴴ2 Vᴴ1
238+
239+
trunc = truncspace(space(S2, 1))
240+
U3, S3, Vᴴ3 = @constinferred svd_trunc(t; trunc)
241+
@test t * Vᴴ3' U3 * S3
242+
@test isisometry(U3)
243+
@test isisometry(Vᴴ3; side=:right)
244+
@test space(S3, 1) space(S2, 1)
245+
246+
trunc = truncerr(0.5)
247+
U4, S4, Vᴴ4 = @constinferred svd_trunc(t; trunc)
248+
@test t * Vᴴ4' U4 * S4
249+
@test isisometry(U4)
250+
@test isisometry(Vᴴ4; side=:right)
251+
@test norm(t - U4 * S4 * Vᴴ4) <= 0.5
252+
end
253+
end
254+
207255
@testset "Eigenvalue decomposition" begin
208256
for T in eltypes,
209257
t in

test/tensors.jl

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -432,50 +432,6 @@ for V in spacelist
432432
@test LinearAlgebra.isdiag(D)
433433
@test LinearAlgebra.diag(D) == d
434434
end
435-
@timedtestset "Tensor truncation" begin
436-
for T in (Float32, ComplexF64)
437-
for p in (1, 2, 3, Inf)
438-
# Test both a normal tensor and an adjoint one.
439-
ts = (randn(T, V1 V2 V3, V4 V5),
440-
randn(T, V4 V5, V1 V2 V3)')
441-
for t in ts
442-
U₀, S₀, V₀, = tsvd(t)
443-
t = rmul!(t, 1 / norm(S₀, p))
444-
U, S, V = @constinferred tsvd(t; trunc=truncerr(5e-1, p))
445-
ϵ = TensorKit._norm(LinearAlgebra.svdvals(U * S * V - t), p,
446-
zero(scalartype(S)))
447-
p == 2 && @test ϵ < 5e-1
448-
# @show p, ϵ
449-
# @show domain(S)
450-
# @test min(space(S,1), space(S₀,1)) != space(S₀,1)
451-
U′, S′, V′ = tsvd(t; trunc=truncerr+ 10eps(ϵ), p))
452-
ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p,
453-
zero(scalartype(S)))
454-
455-
@test (U, S, V, ϵ) == (U′, S′, V′, ϵ′)
456-
U′, S′, V′ = tsvd(t; trunc=truncdim(ceil(Int, dim(domain(S)))))
457-
ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p,
458-
zero(scalartype(S)))
459-
@test (U, S, V, ϵ) == (U′, S′, V′, ϵ′)
460-
U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1)))
461-
ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p,
462-
zero(scalartype(S)))
463-
@test (U, S, V, ϵ) == (U′, S′, V′, ϵ′)
464-
# results with truncationcutoff cannot be compared because they don't take degeneracy into account, and thus truncate differently
465-
U, S, V = tsvd(t; trunc=truncbelow(1 / dim(domain(S₀))))
466-
ϵ = TensorKit._norm(LinearAlgebra.svdvals(U * S * V - t), p,
467-
zero(scalartype(S)))
468-
# @show p, ϵ
469-
# @show domain(S)
470-
# @test min(space(S,1), space(S₀,1)) != space(S₀,1)
471-
U′, S′, V′ = tsvd(t; trunc=truncspace(space(S, 1)))
472-
ϵ′ = TensorKit._norm(LinearAlgebra.svdvals(U′ * S′ * V′ - t), p,
473-
zero(scalartype(S)))
474-
@test (U, S, V, ϵ) == (U′, S′, V′, ϵ′)
475-
end
476-
end
477-
end
478-
end
479435
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
480436
@timedtestset "Tensor functions" begin
481437
W = V1 V2

0 commit comments

Comments
 (0)