diff --git a/Project.toml b/Project.toml index 7a85171..a7c2676 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.4.1" +version = "0.4.2" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index 82859cd..5eec67f 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -17,8 +17,7 @@ export eigen, svd, svd!, svdvals, - svdvals!, - truncerr + svdvals! using LinearAlgebra: LinearAlgebra, norm using MatrixAlgebraKit @@ -143,43 +142,6 @@ end using MatrixAlgebraKit: MatrixAlgebraKit, TruncationStrategy -struct TruncationError{T<:Real} <: TruncationStrategy - atol::T - rtol::T - p::Int -end - -""" - truncerr(; atol::Real=0, rtol::Real=0, p::Int=2) - -Create a truncation strategy for truncating such that the error in the factorization -is smaller than `max(atol, rtol * norm)`, where the error is determined using the `p`-norm. -""" -function truncerr(; atol::Real=0, rtol::Real=0, p::Int=2) - return TruncationError(promote(atol, rtol)..., p) -end - -function MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationError) - Base.require_one_based_indexing(values) - issorted(values; rev=true) || error("Not sorted.") - # norm(values, p) ^ p - normᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values) - ϵᵖ = max(strategy.atol ^ strategy.p, strategy.rtol ^ strategy.p * normᵖ) - if ϵᵖ ≥ normᵖ - return Base.OneTo(0) - end - truncerrᵖ = zero(real(eltype(values))) - rank = length(values) - for i in reverse(eachindex(values)) - truncerrᵖ += abs(values[i]) ^ strategy.p - if truncerrᵖ ≥ ϵᵖ - rank = i - break - end - end - return Base.OneTo(rank) -end - struct TruncationDegenerate{Strategy<:TruncationStrategy,T<:Real} <: TruncationStrategy strategy::Strategy atol::T diff --git a/test/Project.toml b/test/Project.toml index 656853a..07bf51e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,14 +16,14 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] Aqua = "0.8.9" BlockArrays = "1.6.1" -EllipsisNotation = "1.8.0" +EllipsisNotation = "1.8" LinearAlgebra = "<0.0.1, 1" MatrixAlgebraKit = "0.2, 0.3" Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.4.0" +TensorAlgebra = "0.4" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_exports.jl b/test/test_exports.jl index 321164f..0c7b00b 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -45,7 +45,6 @@ using TensorAlgebra: TensorAlgebra :svd!, :svdvals, :svdvals!, - :truncerr, ] @test issetequal(names(TensorAlgebra.MatrixAlgebra), exports) end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 36ee71c..7feefa3 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -1,7 +1,7 @@ using LinearAlgebra: Diagonal, I, diag, isposdef, norm using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank using StableRNGs: StableRNG -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen using Test: @test, @testset elts = (Float32, Float64, ComplexF32, ComplexF64) @@ -152,158 +152,6 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) @test V' * V ≈ I @test MatrixAlgebra.svdvals(A) ≈ diag(S) end - @testset "Truncation" begin - s = Diagonal(real(elt)[1.2, 0.9, 0.3, 0.2, 0.01]) - n = length(diag(s)) - rng = StableRNG(123) - u, _ = qr_compact(randn(rng, elt, n, n); positive=true) - v, _ = qr_compact(randn(rng, elt, n, n); positive=true) - a = u * s * v - - # p = 2, relative = true - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; rtol=norm([0.3, 0.2, 0.01]) / norm(diag(s)) + 10eps(real(elt))) - ) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; rtol=norm([0.3, 0.2, 0.01]) / norm(diag(s)) - 10eps(real(elt))) - ) - @test size(ũ) == (n, 3) - @test size(s̃) == (3, 3) - @test size(ṽ) == (3, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0)) - @test size(ũ) == (n, n) - @test size(s̃) == (n, n) - @test size(ṽ) == (n, n) - @test ũ * s̃ * ṽ ≈ a - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=1)) - @test size(ũ) == (n, 0) - @test size(s̃) == (0, 0) - @test size(ṽ) == (0, n) - @test norm(ũ * s̃ * ṽ) ≈ 0 - - # p = 2, relative = false - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; atol=norm([0.3, 0.2, 0.01]) + 10eps(real(elt))) - ) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; atol=norm([0.3, 0.2, 0.01]) - 10eps(real(elt))) - ) - @test size(ũ) == (n, 3) - @test size(s̃) == (3, 3) - @test size(ṽ) == (3, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0)) - @test size(ũ) == (n, n) - @test size(s̃) == (n, n) - @test size(ṽ) == (n, n) - @test ũ * s̃ * ṽ ≈ a - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; atol=(norm(diag(s)) * (one(real(elt)) + 10eps(real(elt))))) - ) - @test size(ũ) == (n, 0) - @test size(s̃) == (0, 0) - @test size(ṽ) == (0, n) - @test norm(ũ * s̃ * ṽ) ≈ 0 - - # p = 1, relative = true - ũ, s̃, ṽ = svd_trunc( - a; - trunc=truncerr(; - rtol=(norm([0.3, 0.2, 0.01], 1) / norm(diag(s), 1) + 10eps(real(elt))), p=1 - ), - ) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc( - a; - trunc=truncerr(; - rtol=(norm([0.3, 0.2, 0.01], 1) / norm(diag(s), 1) - 10eps(real(elt))), p=1 - ), - ) - @test size(ũ) == (n, 3) - @test size(s̃) == (3, 3) - @test size(ṽ) == (3, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0, p=1)) - @test size(ũ) == (n, n) - @test size(s̃) == (n, n) - @test size(ṽ) == (n, n) - @test ũ * s̃ * ṽ ≈ a - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=1, p=1)) - @test size(ũ) == (n, 0) - @test size(s̃) == (0, 0) - @test size(ṽ) == (0, n) - @test norm(ũ * s̃ * ṽ) ≈ 0 - - # p = 1, relative = false - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; atol=(norm([0.3, 0.2, 0.01], 1) + 10eps(real(elt))), p=1) - ) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc( - a; trunc=truncerr(; atol=(norm([0.3, 0.2, 0.01], 1) - 10eps(real(elt))), p=1) - ) - @test size(ũ) == (n, 3) - @test size(s̃) == (3, 3) - @test size(ṽ) == (3, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0, p=1)) - @test size(ũ) == (n, n) - @test size(s̃) == (n, n) - @test size(ṽ) == (n, n) - @test ũ * s̃ * ṽ ≈ a - ũ, s̃, ṽ = svd_trunc( - a; - trunc=truncerr(; atol=(norm(diag(s), 1) * (one(real(elt)) + 10eps(real(elt)))), p=1), - ) - @test size(ũ) == (n, 0) - @test size(s̃) == (0, 0) - @test size(ṽ) == (0, n) - @test norm(ũ * s̃ * ṽ) ≈ 0 - - # Specifying both `atol` and `rtol`. - s = Diagonal(real(elt)[0.1, 0.01, 0.001]) - n = length(diag(s)) - rng = StableRNG(123) - u, _ = qr_compact(randn(rng, elt, n, n); positive=true) - v, _ = qr_compact(randn(rng, elt, n, n); positive=true) - a = u * s * v - - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0.002)) - @test size(ũ) == (n, n) - @test size(s̃) == (n, n) - @test size(ṽ) == (n, n) - @test ũ * s̃ * ṽ ≈ a - @test ũ * s̃ * ṽ ≈ a rtol = 0.002 - - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0.002)) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) - @test ũ * s̃ * ṽ ≈ a atol = 0.002 - - ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0.002, rtol=0.002)) - @test size(ũ) == (n, 2) - @test size(s̃) == (2, 2) - @test size(ṽ) == (2, n) - @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) - @test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002 - end @testset "Truncate degenerate" begin s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01]) n = length(diag(s))