Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.4"
version = "0.3.5"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
54 changes: 54 additions & 0 deletions src/MatrixAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,58 @@
return Base.OneTo(rank)
end

struct TruncationDegenerate{Strategy<:TruncationStrategy,T<:Real} <: TruncationStrategy
strategy::Strategy
atol::T
rtol::T
end

"""
truncdegen(trunc::TruncationStrategy; atol::Real=0, rtol::Real=0)
Modify a truncation strategy so that if the truncation falls within
a degenerate subspace, the entire subspace gets truncated as well.
Adjacent values `v1` and `v2` in the spectrum are considered to be degenerate if
`≈(v1, v2; atol, rtol)`.
For now, this truncation strategy assumes the spectrum being truncated
has already been reverse sorted and the strategy being wrapped
outputs a contiguous subset of values including the largest one. It
also only truncates for now, so may not respect if a minimum dimension
was requested in the strategy being wrapped. These restrictions may
be lifted in the future or provided through a different truncation strategy.
"""
function truncdegen(strategy::TruncationStrategy; atol::Real=0, rtol::Real=0)
return TruncationDegenerate(strategy, promote(atol, rtol)...)

Check warning on line 198 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L197-L198

Added lines #L197 - L198 were not covered by tests
end

using MatrixAlgebraKit: findtruncated

function MatrixAlgebraKit.findtruncated(

Check warning on line 203 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L203

Added line #L203 was not covered by tests
values::AbstractVector, strategy::TruncationDegenerate
)
Base.require_one_based_indexing(values)
issorted(values; rev=true) || throw(ArgumentError("Values aren't reverse sorted."))
indices_collection = findtruncated(values, strategy.strategy)
indices = Base.OneTo(maximum(indices_collection))
indices_collection == indices ||

Check warning on line 210 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L206-L210

Added lines #L206 - L210 were not covered by tests
throw(ArgumentError("Truncation must be a contiguous range."))
if length(indices_collection) == length(values)

Check warning on line 212 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L212

Added line #L212 was not covered by tests
# No truncation occured.
return indices

Check warning on line 214 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L214

Added line #L214 was not covered by tests
end
# Value of the largest truncated value.
val = values[last(indices) + 1]
ind = last(indices)
for i in reverse(Base.OneTo(last(indices)))
if (values[i], val; atol=strategy.atol, rtol=strategy.rtol)
ind = i - 1
val = values[i]

Check warning on line 222 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L217-L222

Added lines #L217 - L222 were not covered by tests
else
break

Check warning on line 224 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L224

Added line #L224 was not covered by tests
end
end
return Base.OneTo(ind)

Check warning on line 227 in src/MatrixAlgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/MatrixAlgebra.jl#L226-L227

Added lines #L226 - L227 were not covered by tests
end

end
105 changes: 103 additions & 2 deletions test/test_matrixalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LinearAlgebra: Diagonal, I, diag, isposdef, norm
using MatrixAlgebraKit: qr_compact, svd_trunc
using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank
using StableRNGs: StableRNG
using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncerr
using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr
using Test: @test, @testset

elts = (Float32, Float64, ComplexF32, ComplexF64)
Expand Down Expand Up @@ -304,4 +304,105 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001])
@test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002
end
@testset "Truncate degenerate" begin
s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01])
n = length(diag(s))
rng = StableRNG(123)
u, _ = qr_compact(randn(rng, elt, n, n); positive=true)
v, _ = qr_compact(randn(rng, elt, n, n); positive=true)
a = u * s * v

ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(n); atol=0.1))
@test size(ũ) == (n, n)
@test size(s̃) == (n, n)
@test size(ṽ) == (n, n)
@test ũ * s̃ * ṽ ≈ a

for kwargs in (
(; atol=eps(real(elt))),
(; rtol=(√eps(real(elt)))),
(; atol=eps(real(elt)), rtol=(√eps(real(elt)))),
)
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(5); kwargs...))
@test size(ũ) == (n, 4)
@test size(s̃) == (4, 4)
@test size(ṽ) == (4, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01])
end

for kwargs in (
(; atol=eps(real(elt))),
(; rtol=eps(real(elt))),
(; atol=eps(real(elt)), rtol=eps(real(elt))),
)
ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(4); kwargs...))
@test size(ũ) == (n, 4)
@test size(s̃) == (4, 4)
@test size(ṽ) == (4, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01])
end

trunc = truncdegen(truncrank(3); atol=0.01 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 3)
@test size(s̃) == (3, 3)
@test size(ṽ) == (3, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.01/0.3 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 3)
@test size(s̃) == (3, 3)
@test size(ṽ) == (3, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.01 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.01/0.3 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.02 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.02/0.32 - √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 2)
@test size(s̃) == (2, 2)
@test size(ṽ) == (2, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.02 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); rtol=0.02/0.32 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])

trunc = truncdegen(truncrank(3); atol=0.01, rtol=0.02/0.32 + √eps(real(elt)))
ũ, s̃, ṽ = svd_trunc(a; trunc)
@test size(ũ) == (n, 1)
@test size(s̃) == (1, 1)
@test size(ṽ) == (1, n)
@test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01])
end
end
Loading