Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.4"
version = "0.3.5"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
54 changes: 53 additions & 1 deletion src/MatrixAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export eigen,
svdvals!,
truncerr

using LinearAlgebra: LinearAlgebra
using LinearAlgebra: LinearAlgebra, norm
using MatrixAlgebraKit

for (f, f_full, f_compact) in (
Expand Down Expand Up @@ -173,4 +173,56 @@ function MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::Trunca
return Base.OneTo(rank)
end

struct TruncationDegenerate{Strategy<:TruncationStrategy,T<:Real} <: TruncationStrategy
strategy::Strategy
atol::T
rtol::T
end

"""
truncdegen(trunc::TruncationStrategy; atol::Real=0, rtol::Real=0)

Modify a truncation strategy so that if the truncation falls within
a degenerate subspace, the entire subspace gets truncated as well.
A value `val` is considered degenerate if
`norm(val - truncval) ≤ max(atol, rtol * norm(truncval))`
where `truncval` is the largest value truncated by the original
truncation strategy `trunc`.

For now, this truncation strategy assumes the spectrum being truncated
has already been reverse sorted and the strategy being wrapped
outputs a contiguous subset of values including the largest one. It
also only truncates for now, so may not respect if a minimum dimension
was requested in the strategy being wrapped. These restrictions may
be lifted in the future or provided through a different truncation strategy.
"""
function truncdegen(strategy::TruncationStrategy; atol::Real=0, rtol::Real=0)
return TruncationDegenerate(strategy, promote(atol, rtol)...)
end

using MatrixAlgebraKit: findtruncated

function MatrixAlgebraKit.findtruncated(
values::AbstractVector, strategy::TruncationDegenerate
)
Base.require_one_based_indexing(values)
issorted(values; rev=true) || throw(ArgumentError("Values must be reverse sorted."))
indices_collection = findtruncated(values, strategy.strategy)
indices = Base.OneTo(maximum(indices_collection))
indices_collection == indices ||
throw(ArgumentError("Truncation must be a contiguous range."))
if length(indices_collection) == length(values)
# No truncation occurred.
return indices
end
# The largest truncated value.
truncval = values[last(indices) + 1]
# Tolerance of determining if a value is degenerate.
atol = max(strategy.atol, strategy.rtol * abs(truncval))
for rank in reverse(indices)
≈(values[rank], truncval; atol, rtol=0) || return Base.OneTo(rank)
end
return Base.OneTo(0)
end

end
140 changes: 138 additions & 2 deletions test/test_matrixalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LinearAlgebra: Diagonal, I, diag, isposdef, norm
using MatrixAlgebraKit: qr_compact, svd_trunc
using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank
using StableRNGs: StableRNG
using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncerr
using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr
using Test: @test, @testset

elts = (Float32, Float64, ComplexF32, ComplexF64)
Expand Down Expand Up @@ -304,4 +304,140 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001])
@test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002
end
@testset "Truncate degenerate" begin
s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01])
n = length(diag(s))
rng = StableRNG(123)
u, _ = qr_compact(randn(rng, elt, n, n); positive=true)
v, _ = qr_compact(randn(rng, elt, n, n); positive=true)
a = u * s * v

ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(n); atol=0.1))
@test size(ũ) == (n, n)
@test size(s̃) == (n, n)
@test size(ṽ) == (n, n)
@test ũ * s̃ * ṽ ≈ a

for kwargs in (
(; atol=eps(real(elt))),
(; rtol=(√eps(real(elt)))),
(; atol=eps(real(elt)), rtol=(√eps(real(elt)))),
)
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(5); kwargs...))
@test size(ũ) == (n, 4)
@test size(s̃) == (4, 4)
@test size(ṽ) == (4, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01])
end

for kwargs in (
(; atol=eps(real(elt))),
(; rtol=eps(real(elt))),
(; atol=eps(real(elt)), rtol=eps(real(elt))),
)
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(4); kwargs...))
@test size(ũ) == (n, 4)
@test size(s̃) == (4, 4)
@test size(ṽ) == (4, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01])
end

trunc = truncdegen(truncrank(3); atol=0.01 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 3)
@test size(s̃) == (3, 3)
@test size(ṽ) == (3, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.01/0.3 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 3)
@test size(s̃) == (3, 3)
@test size(ṽ) == (3, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.01 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.01/0.29 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.02 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.02/0.29 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.03/0.29 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.01, rtol=0.03/0.29 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt)), rtol=0.01/0.29)
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=(2 - 0.29) - √(eps(real(elt))))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 - √(eps(real(elt))))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=(2 - 0.29) + √(eps(real(elt))))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 0)
@test size(s̃) == (0, 0)
@test size(ṽ) == (0, n)
@test norm(ũ * s̃ * ṽ) ≈ 0

trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 + √(eps(real(elt))))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 0)
@test size(s̃) == (0, 0)
@test size(ṽ) == (0, n)
@test norm(ũ * s̃ * ṽ) ≈ 0
end
end
Loading