diff --git a/Project.toml b/Project.toml index 8751572..5e8781e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.4" +version = "0.3.5" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index d99d5e1..0731a4b 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -20,7 +20,7 @@ export eigen, svdvals!, truncerr -using LinearAlgebra: LinearAlgebra +using LinearAlgebra: LinearAlgebra, norm using MatrixAlgebraKit for (f, f_full, f_compact) in ( @@ -173,4 +173,56 @@ function MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::Trunca return Base.OneTo(rank) end +struct TruncationDegenerate{Strategy<:TruncationStrategy,T<:Real} <: TruncationStrategy + strategy::Strategy + atol::T + rtol::T +end + +""" + truncdegen(trunc::TruncationStrategy; atol::Real=0, rtol::Real=0) + +Modify a truncation strategy so that if the truncation falls within +a degenerate subspace, the entire subspace gets truncated as well. +A value `val` is considered degenerate if +`norm(val - truncval) ≤ max(atol, rtol * norm(truncval))` +where `truncval` is the largest value truncated by the original +truncation strategy `trunc`. + +For now, this truncation strategy assumes the spectrum being truncated +has already been reverse sorted and the strategy being wrapped +outputs a contiguous subset of values including the largest one. It +also only truncates for now, so may not respect if a minimum dimension +was requested in the strategy being wrapped. These restrictions may +be lifted in the future or provided through a different truncation strategy. +""" +function truncdegen(strategy::TruncationStrategy; atol::Real=0, rtol::Real=0) + return TruncationDegenerate(strategy, promote(atol, rtol)...) +end + +using MatrixAlgebraKit: findtruncated + +function MatrixAlgebraKit.findtruncated( + values::AbstractVector, strategy::TruncationDegenerate +) + Base.require_one_based_indexing(values) + issorted(values; rev=true) || throw(ArgumentError("Values must be reverse sorted.")) + indices_collection = findtruncated(values, strategy.strategy) + indices = Base.OneTo(maximum(indices_collection)) + indices_collection == indices || + throw(ArgumentError("Truncation must be a contiguous range.")) + if length(indices_collection) == length(values) + # No truncation occurred. + return indices + end + # The largest truncated value. + truncval = values[last(indices) + 1] + # Tolerance of determining if a value is degenerate. + atol = max(strategy.atol, strategy.rtol * abs(truncval)) + for rank in reverse(indices) + ≈(values[rank], truncval; atol, rtol=0) || return Base.OneTo(rank) + end + return Base.OneTo(0) +end + end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 012ea10..d95e51b 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -1,7 +1,7 @@ using LinearAlgebra: Diagonal, I, diag, isposdef, norm -using MatrixAlgebraKit: qr_compact, svd_trunc +using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank using StableRNGs: StableRNG -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncerr +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr using Test: @test, @testset elts = (Float32, Float64, ComplexF32, ComplexF64) @@ -304,4 +304,140 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) @test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002 end + @testset "Truncate degenerate" begin + s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01]) + n = length(diag(s)) + rng = StableRNG(123) + u, _ = qr_compact(randn(rng, elt, n, n); positive=true) + v, _ = qr_compact(randn(rng, elt, n, n); positive=true) + a = u * s * v + + ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(n); atol=0.1)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a + + for kwargs in ( + (; atol=eps(real(elt))), + (; rtol=(√eps(real(elt)))), + (; atol=eps(real(elt)), rtol=(√eps(real(elt)))), + ) + ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(5); kwargs...)) + @test size(ũ) == (n, 4) + @test size(s̃) == (4, 4) + @test size(ṽ) == (4, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) + end + + for kwargs in ( + (; atol=eps(real(elt))), + (; rtol=eps(real(elt))), + (; atol=eps(real(elt)), rtol=eps(real(elt))), + ) + ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(4); kwargs...)) + @test size(ũ) == (n, 4) + @test size(s̃) == (4, 4) + @test size(ṽ) == (4, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) + end + + trunc = truncdegen(truncrank(3); atol=0.01 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); rtol=0.01/0.3 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=0.01 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); rtol=0.01/0.29 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=0.02 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); rtol=0.02/0.29 - √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); rtol=0.03/0.29 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=0.01, rtol=0.03/0.29 + √eps(real(elt))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt)), rtol=0.01/0.29) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=(2 - 0.29) - √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 - √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 1) + @test size(s̃) == (1, 1) + @test size(ṽ) == (1, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) + + trunc = truncdegen(truncrank(3); atol=(2 - 0.29) + √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 + + trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 + √(eps(real(elt)))) + ũ, s̃, ṽ = svd_trunc(a; trunc) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 + end end