Skip to content

Commit bb930c7

Browse files
authored
Truncation strategy for degenerate subspaces (#60)
1 parent f166200 commit bb930c7

File tree

3 files changed

+192
-4
lines changed

3 files changed

+192
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/MatrixAlgebra.jl

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ export eigen,
2020
svdvals!,
2121
truncerr
2222

23-
using LinearAlgebra: LinearAlgebra
23+
using LinearAlgebra: LinearAlgebra, norm
2424
using MatrixAlgebraKit
2525

2626
for (f, f_full, f_compact) in (
@@ -173,4 +173,56 @@ function MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::Trunca
173173
return Base.OneTo(rank)
174174
end
175175

176+
struct TruncationDegenerate{Strategy<:TruncationStrategy,T<:Real} <: TruncationStrategy
177+
strategy::Strategy
178+
atol::T
179+
rtol::T
180+
end
181+
182+
"""
183+
truncdegen(trunc::TruncationStrategy; atol::Real=0, rtol::Real=0)
184+
185+
Modify a truncation strategy so that if the truncation falls within
186+
a degenerate subspace, the entire subspace gets truncated as well.
187+
A value `val` is considered degenerate if
188+
`norm(val - truncval) ≤ max(atol, rtol * norm(truncval))`
189+
where `truncval` is the largest value truncated by the original
190+
truncation strategy `trunc`.
191+
192+
For now, this truncation strategy assumes the spectrum being truncated
193+
has already been reverse sorted and the strategy being wrapped
194+
outputs a contiguous subset of values including the largest one. It
195+
also only truncates for now, so may not respect if a minimum dimension
196+
was requested in the strategy being wrapped. These restrictions may
197+
be lifted in the future or provided through a different truncation strategy.
198+
"""
199+
function truncdegen(strategy::TruncationStrategy; atol::Real=0, rtol::Real=0)
200+
return TruncationDegenerate(strategy, promote(atol, rtol)...)
201+
end
202+
203+
using MatrixAlgebraKit: findtruncated
204+
205+
function MatrixAlgebraKit.findtruncated(
206+
values::AbstractVector, strategy::TruncationDegenerate
207+
)
208+
Base.require_one_based_indexing(values)
209+
issorted(values; rev=true) || throw(ArgumentError("Values must be reverse sorted."))
210+
indices_collection = findtruncated(values, strategy.strategy)
211+
indices = Base.OneTo(maximum(indices_collection))
212+
indices_collection == indices ||
213+
throw(ArgumentError("Truncation must be a contiguous range."))
214+
if length(indices_collection) == length(values)
215+
# No truncation occurred.
216+
return indices
217+
end
218+
# The largest truncated value.
219+
truncval = values[last(indices) + 1]
220+
# Tolerance of determining if a value is degenerate.
221+
atol = max(strategy.atol, strategy.rtol * abs(truncval))
222+
for rank in reverse(indices)
223+
(values[rank], truncval; atol, rtol=0) || return Base.OneTo(rank)
224+
end
225+
return Base.OneTo(0)
226+
end
227+
176228
end

test/test_matrixalgebra.jl

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using LinearAlgebra: Diagonal, I, diag, isposdef, norm
2-
using MatrixAlgebraKit: qr_compact, svd_trunc
2+
using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank
33
using StableRNGs: StableRNG
4-
using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncerr
4+
using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr
55
using Test: @test, @testset
66

77
elts = (Float32, Float64, ComplexF32, ComplexF64)
@@ -304,4 +304,140 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
304304
@test norm(ũ **- a) norm([0.001])
305305
@test** a atol = 0.002 rtol = 0.002
306306
end
307+
@testset "Truncate degenerate" begin
308+
s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01])
309+
n = length(diag(s))
310+
rng = StableRNG(123)
311+
u, _ = qr_compact(randn(rng, elt, n, n); positive=true)
312+
v, _ = qr_compact(randn(rng, elt, n, n); positive=true)
313+
a = u * s * v
314+
315+
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(n); atol=0.1))
316+
@test size(ũ) == (n, n)
317+
@test size(s̃) == (n, n)
318+
@test size(ṽ) == (n, n)
319+
@test** a
320+
321+
for kwargs in (
322+
(; atol=eps(real(elt))),
323+
(; rtol=(eps(real(elt)))),
324+
(; atol=eps(real(elt)), rtol=(eps(real(elt)))),
325+
)
326+
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(5); kwargs...))
327+
@test size(ũ) == (n, 4)
328+
@test size(s̃) == (4, 4)
329+
@test size(ṽ) == (4, n)
330+
@test norm(ũ **- a) norm([0.01, 0.01])
331+
end
332+
333+
for kwargs in (
334+
(; atol=eps(real(elt))),
335+
(; rtol=eps(real(elt))),
336+
(; atol=eps(real(elt)), rtol=eps(real(elt))),
337+
)
338+
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(4); kwargs...))
339+
@test size(ũ) == (n, 4)
340+
@test size(s̃) == (4, 4)
341+
@test size(ṽ) == (4, n)
342+
@test norm(ũ **- a) norm([0.01, 0.01])
343+
end
344+
345+
trunc = truncdegen(truncrank(3); atol=0.01 - eps(real(elt)))
346+
ũ, s̃, ṽ = svd_trunc(a; trunc)
347+
@test size(ũ) == (n, 3)
348+
@test size(s̃) == (3, 3)
349+
@test size(ṽ) == (3, n)
350+
@test norm(ũ **- a) norm([0.29, 0.01, 0.01])
351+
352+
trunc = truncdegen(truncrank(3); rtol=0.01/0.3 - eps(real(elt)))
353+
ũ, s̃, ṽ = svd_trunc(a; trunc)
354+
@test size(ũ) == (n, 3)
355+
@test size(s̃) == (3, 3)
356+
@test size(ṽ) == (3, n)
357+
@test norm(ũ **- a) norm([0.29, 0.01, 0.01])
358+
359+
trunc = truncdegen(truncrank(3); atol=0.01 + eps(real(elt)))
360+
ũ, s̃, ṽ = svd_trunc(a; trunc)
361+
@test size(ũ) == (n, 2)
362+
@test size(s̃) == (2, 2)
363+
@test size(ṽ) == (2, n)
364+
@test norm(ũ **- a) norm([0.3, 0.29, 0.01, 0.01])
365+
366+
trunc = truncdegen(truncrank(3); rtol=0.01/0.29 + eps(real(elt)))
367+
ũ, s̃, ṽ = svd_trunc(a; trunc)
368+
@test size(ũ) == (n, 2)
369+
@test size(s̃) == (2, 2)
370+
@test size(ṽ) == (2, n)
371+
@test norm(ũ **- a) norm([0.3, 0.29, 0.01, 0.01])
372+
373+
trunc = truncdegen(truncrank(3); atol=0.02 - eps(real(elt)))
374+
ũ, s̃, ṽ = svd_trunc(a; trunc)
375+
@test size(ũ) == (n, 2)
376+
@test size(s̃) == (2, 2)
377+
@test size(ṽ) == (2, n)
378+
@test norm(ũ **- a) norm([0.3, 0.29, 0.01, 0.01])
379+
380+
trunc = truncdegen(truncrank(3); rtol=0.02/0.29 - eps(real(elt)))
381+
ũ, s̃, ṽ = svd_trunc(a; trunc)
382+
@test size(ũ) == (n, 2)
383+
@test size(s̃) == (2, 2)
384+
@test size(ṽ) == (2, n)
385+
@test norm(ũ **- a) norm([0.3, 0.29, 0.01, 0.01])
386+
387+
trunc = truncdegen(truncrank(3); atol=0.03 + eps(real(elt)))
388+
ũ, s̃, ṽ = svd_trunc(a; trunc)
389+
@test size(ũ) == (n, 1)
390+
@test size(s̃) == (1, 1)
391+
@test size(ṽ) == (1, n)
392+
@test norm(ũ **- a) norm([0.32, 0.3, 0.29, 0.01, 0.01])
393+
394+
trunc = truncdegen(truncrank(3); rtol=0.03/0.29 + eps(real(elt)))
395+
ũ, s̃, ṽ = svd_trunc(a; trunc)
396+
@test size(ũ) == (n, 1)
397+
@test size(s̃) == (1, 1)
398+
@test size(ṽ) == (1, n)
399+
@test norm(ũ **- a) norm([0.32, 0.3, 0.29, 0.01, 0.01])
400+
401+
trunc = truncdegen(truncrank(3); atol=0.01, rtol=0.03/0.29 + eps(real(elt)))
402+
ũ, s̃, ṽ = svd_trunc(a; trunc)
403+
@test size(ũ) == (n, 1)
404+
@test size(s̃) == (1, 1)
405+
@test size(ṽ) == (1, n)
406+
@test norm(ũ **- a) norm([0.32, 0.3, 0.29, 0.01, 0.01])
407+
408+
trunc = truncdegen(truncrank(3); atol=0.03 + eps(real(elt)), rtol=0.01/0.29)
409+
ũ, s̃, ṽ = svd_trunc(a; trunc)
410+
@test size(ũ) == (n, 1)
411+
@test size(s̃) == (1, 1)
412+
@test size(ṽ) == (1, n)
413+
@test norm(ũ **- a) norm([0.32, 0.3, 0.29, 0.01, 0.01])
414+
415+
trunc = truncdegen(truncrank(3); atol=(2 - 0.29) - (eps(real(elt))))
416+
ũ, s̃, ṽ = svd_trunc(a; trunc)
417+
@test size(ũ) == (n, 1)
418+
@test size(s̃) == (1, 1)
419+
@test size(ṽ) == (1, n)
420+
@test norm(ũ **- a) norm([0.32, 0.3, 0.29, 0.01, 0.01])
421+
422+
trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 - (eps(real(elt))))
423+
ũ, s̃, ṽ = svd_trunc(a; trunc)
424+
@test size(ũ) == (n, 1)
425+
@test size(s̃) == (1, 1)
426+
@test size(ṽ) == (1, n)
427+
@test norm(ũ **- a) norm([0.32, 0.3, 0.29, 0.01, 0.01])
428+
429+
trunc = truncdegen(truncrank(3); atol=(2 - 0.29) + (eps(real(elt))))
430+
ũ, s̃, ṽ = svd_trunc(a; trunc)
431+
@test size(ũ) == (n, 0)
432+
@test size(s̃) == (0, 0)
433+
@test size(ṽ) == (0, n)
434+
@test norm(ũ ** ṽ) 0
435+
436+
trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 + (eps(real(elt))))
437+
ũ, s̃, ṽ = svd_trunc(a; trunc)
438+
@test size(ũ) == (n, 0)
439+
@test size(s̃) == (0, 0)
440+
@test size(ṽ) == (0, n)
441+
@test norm(ũ ** ṽ) 0
442+
end
307443
end

0 commit comments

Comments
 (0)