diff --git a/Project.toml b/Project.toml index d50db71..8751572 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.3" +version = "0.3.4" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/docs/src/reference.md b/docs/src/reference.md index bc6be34..5724e2e 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -1,5 +1,5 @@ # Reference ```@autodocs -Modules = [TensorAlgebra] +Modules = [TensorAlgebra, TensorAlgebra.MatrixAlgebra] ``` diff --git a/src/MatrixAlgebra.jl b/src/MatrixAlgebra.jl index a3796b9..d99d5e1 100644 --- a/src/MatrixAlgebra.jl +++ b/src/MatrixAlgebra.jl @@ -17,7 +17,8 @@ export eigen, svd, svd!, svdvals, - svdvals! + svdvals!, + truncerr using LinearAlgebra: LinearAlgebra using MatrixAlgebraKit @@ -133,4 +134,43 @@ for (factorize, orth_f) in ((:factorize, :(MatrixAlgebra.orth)), (:factorize!, : end end +using MatrixAlgebraKit: MatrixAlgebraKit, TruncationStrategy + +struct TruncationError{T<:Real} <: TruncationStrategy + atol::T + rtol::T + p::Int +end + +""" + truncerr(; atol::Real=0, rtol::Real=0, p::Int=2) + +Create a truncation strategy for truncating such that the error in the factorization +is smaller than `max(atol, rtol * norm)`, where the error is determined using the `p`-norm. +""" +function truncerr(; atol::Real=0, rtol::Real=0, p::Int=2) + return TruncationError(promote(atol, rtol)..., p) +end + +function MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationError) + Base.require_one_based_indexing(values) + issorted(values; rev=true) || error("Not sorted.") + # norm(values, p) ^ p + normᵖ = sum(Base.Fix2(^, strategy.p) ∘ abs, values) + ϵᵖ = max(strategy.atol ^ strategy.p, strategy.rtol ^ strategy.p * normᵖ) + if ϵᵖ ≥ normᵖ + return Base.OneTo(0) + end + truncerrᵖ = zero(real(eltype(values))) + rank = length(values) + for i in reverse(eachindex(values)) + truncerrᵖ += abs(values[i]) ^ strategy.p + if truncerrᵖ ≥ ϵᵖ + rank = i + break + end + end + return Base.OneTo(rank) +end + end diff --git a/test/test_exports.jl b/test/test_exports.jl index 0c7b00b..321164f 100644 --- a/test/test_exports.jl +++ b/test/test_exports.jl @@ -45,6 +45,7 @@ using TensorAlgebra: TensorAlgebra :svd!, :svdvals, :svdvals!, + :truncerr, ] @test issetequal(names(TensorAlgebra.MatrixAlgebra), exports) end diff --git a/test/test_matrixalgebra.jl b/test/test_matrixalgebra.jl index 7ee3598..012ea10 100644 --- a/test/test_matrixalgebra.jl +++ b/test/test_matrixalgebra.jl @@ -1,149 +1,307 @@ -using LinearAlgebra: I, diag, isposdef +using LinearAlgebra: Diagonal, I, diag, isposdef, norm +using MatrixAlgebraKit: qr_compact, svd_trunc +using StableRNGs: StableRNG +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncerr using Test: @test, @testset -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra - elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "TensorAlgebra.MatrixAlgebra (elt=$elt)" for elt in elts - A = randn(elt, 3, 2) - for positive in (false, true) - for (Q, R) in (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full=false, positive)) + @testset "Factorizations" begin + rng = StableRNG(123) + A = randn(rng, elt, 3, 2) + for positive in (false, true) + for (Q, R) in + (MatrixAlgebra.qr(A; positive), MatrixAlgebra.qr(A; full=false, positive)) + @test A ≈ Q * R + @test size(Q) == size(A) + @test size(R) == (size(A, 2), size(A, 2)) + @test Q' * Q ≈ I + @test Q * Q' ≉ I + if positive + @test all(≥(0), real(diag(R))) + @test all(≈(0), imag(diag(R))) + end + end + end + + A = randn(elt, 3, 2) + for positive in (false, true) + Q, R = MatrixAlgebra.qr(A; full=true, positive) @test A ≈ Q * R - @test size(Q) == size(A) - @test size(R) == (size(A, 2), size(A, 2)) + @test size(Q) == (size(A, 1), size(A, 1)) + @test size(R) == size(A) @test Q' * Q ≈ I - @test Q * Q' ≉ I + @test Q * Q' ≈ I if positive @test all(≥(0), real(diag(R))) @test all(≈(0), imag(diag(R))) end end - end - A = randn(elt, 3, 2) - for positive in (false, true) - Q, R = MatrixAlgebra.qr(A; full=true, positive) - @test A ≈ Q * R - @test size(Q) == (size(A, 1), size(A, 1)) - @test size(R) == size(A) - @test Q' * Q ≈ I - @test Q * Q' ≈ I - if positive - @test all(≥(0), real(diag(R))) - @test all(≈(0), imag(diag(R))) + A = randn(elt, 2, 3) + for positive in (false, true) + for (L, Q) in + (MatrixAlgebra.lq(A; positive), MatrixAlgebra.lq(A; full=false, positive)) + @test A ≈ L * Q + @test size(L) == (size(A, 1), size(A, 1)) + @test size(Q) == size(A) + @test Q * Q' ≈ I + @test Q' * Q ≉ I + if positive + @test all(≥(0), real(diag(L))) + @test all(≈(0), imag(diag(L))) + end + end end - end - A = randn(elt, 2, 3) - for positive in (false, true) - for (L, Q) in (MatrixAlgebra.lq(A; positive), MatrixAlgebra.lq(A; full=false, positive)) + A = randn(elt, 3, 2) + for positive in (false, true) + L, Q = MatrixAlgebra.lq(A; full=true, positive) @test A ≈ L * Q - @test size(L) == (size(A, 1), size(A, 1)) - @test size(Q) == size(A) + @test size(L) == size(A) + @test size(Q) == (size(A, 2), size(A, 2)) @test Q * Q' ≈ I - @test Q' * Q ≉ I + @test Q' * Q ≈ I if positive @test all(≥(0), real(diag(L))) @test all(≈(0), imag(diag(L))) end end - end - A = randn(elt, 3, 2) - for positive in (false, true) - L, Q = MatrixAlgebra.lq(A; full=true, positive) - @test A ≈ L * Q - @test size(L) == size(A) - @test size(Q) == (size(A, 2), size(A, 2)) - @test Q * Q' ≈ I - @test Q' * Q ≈ I - if positive - @test all(≥(0), real(diag(L))) - @test all(≈(0), imag(diag(L))) + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.orth(A), MatrixAlgebra.orth(A; side=:left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I end - end - A = randn(elt, 3, 2) - for (W, C) in (MatrixAlgebra.orth(A), MatrixAlgebra.orth(A; side=:left)) - @test A ≈ W * C + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.orth(A; side=:right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) @test size(W) == size(A) - @test size(C) == (size(A, 2), size(A, 2)) - @test W' * W ≈ I - @test W * W' ≉ I - end + @test W * W' ≈ I + @test W' * W ≉ I - A = randn(elt, 2, 3) - C, W = MatrixAlgebra.orth(A; side=:right) - @test A ≈ C * W - @test size(C) == (size(A, 1), size(A, 1)) - @test size(W) == size(A) - @test W * W' ≈ I - @test W' * W ≉ I - - A = randn(elt, 3, 2) - for (W, P) in (MatrixAlgebra.polar(A), MatrixAlgebra.polar(A; side=:left)) - @test A ≈ W * P + A = randn(elt, 3, 2) + for (W, P) in (MatrixAlgebra.polar(A), MatrixAlgebra.polar(A; side=:left)) + @test A ≈ W * P + @test size(W) == size(A) + @test size(P) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + @test isposdef(P) + end + + A = randn(elt, 2, 3) + P, W = MatrixAlgebra.polar(A; side=:right) + @test A ≈ P * W + @test size(P) == (size(A, 1), size(A, 1)) @test size(W) == size(A) - @test size(P) == (size(A, 2), size(A, 2)) - @test W' * W ≈ I - @test W * W' ≉ I + @test W * W' ≈ I + @test W' * W ≉ I @test isposdef(P) - end - A = randn(elt, 2, 3) - P, W = MatrixAlgebra.polar(A; side=:right) - @test A ≈ P * W - @test size(P) == (size(A, 1), size(A, 1)) - @test size(W) == size(A) - @test W * W' ≈ I - @test W' * W ≉ I - @test isposdef(P) - - A = randn(elt, 3, 2) - for (W, C) in (MatrixAlgebra.factorize(A), MatrixAlgebra.factorize(A; orth=:left)) - @test A ≈ W * C + A = randn(elt, 3, 2) + for (W, C) in (MatrixAlgebra.factorize(A), MatrixAlgebra.factorize(A; orth=:left)) + @test A ≈ W * C + @test size(W) == size(A) + @test size(C) == (size(A, 2), size(A, 2)) + @test W' * W ≈ I + @test W * W' ≉ I + end + + A = randn(elt, 2, 3) + C, W = MatrixAlgebra.factorize(A; orth=:right) + @test A ≈ C * W + @test size(C) == (size(A, 1), size(A, 1)) @test size(W) == size(A) - @test size(C) == (size(A, 2), size(A, 2)) - @test W' * W ≈ I - @test W * W' ≉ I - end + @test W * W' ≈ I + @test W' * W ≉ I + + A = randn(elt, 3, 3) + D, V = MatrixAlgebra.eigen(A) + @test A * V ≈ V * D + @test MatrixAlgebra.eigvals(A) ≈ diag(D) - A = randn(elt, 2, 3) - C, W = MatrixAlgebra.factorize(A; orth=:right) - @test A ≈ C * W - @test size(C) == (size(A, 1), size(A, 1)) - @test size(W) == size(A) - @test W * W' ≈ I - @test W' * W ≉ I - - A = randn(elt, 3, 3) - D, V = MatrixAlgebra.eigen(A) - @test A * V ≈ V * D - @test MatrixAlgebra.eigvals(A) ≈ diag(D) - - A = randn(elt, 3, 2) - for (U, S, V) in (MatrixAlgebra.svd(A), MatrixAlgebra.svd(A; full=false)) + A = randn(elt, 3, 2) + for (U, S, V) in (MatrixAlgebra.svd(A), MatrixAlgebra.svd(A; full=false)) + @test A ≈ U * S * V + @test size(U) == size(A) + @test size(S) == (size(A, 2), size(A, 2)) + @test size(V) == (size(A, 2), size(A, 2)) + @test U' * U ≈ I + @test U * U' ≉ I + @test V * V' ≈ I + @test V' * V ≈ I + @test MatrixAlgebra.svdvals(A) ≈ diag(S) + end + + A = randn(elt, 3, 2) + U, S, V = MatrixAlgebra.svd(A; full=true) @test A ≈ U * S * V - @test size(U) == size(A) - @test size(S) == (size(A, 2), size(A, 2)) + @test size(U) == (size(A, 1), size(A, 1)) + @test size(S) == size(A) @test size(V) == (size(A, 2), size(A, 2)) @test U' * U ≈ I - @test U * U' ≉ I + @test U * U' ≈ I @test V * V' ≈ I @test V' * V ≈ I @test MatrixAlgebra.svdvals(A) ≈ diag(S) end + @testset "Truncation" begin + s = Diagonal(real(elt)[1.2, 0.9, 0.3, 0.2, 0.01]) + n = length(diag(s)) + rng = StableRNG(123) + u, _ = qr_compact(randn(rng, elt, n, n); positive=true) + v, _ = qr_compact(randn(rng, elt, n, n); positive=true) + a = u * s * v + + # p = 2, relative = true + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; rtol=norm([0.3, 0.2, 0.01]) / norm(diag(s)) + eps(real(elt))) + ) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; rtol=norm([0.3, 0.2, 0.01]) / norm(diag(s)) - 10eps(real(elt))) + ) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=1)) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 - A = randn(elt, 3, 2) - U, S, V = MatrixAlgebra.svd(A; full=true) - @test A ≈ U * S * V - @test size(U) == (size(A, 1), size(A, 1)) - @test size(S) == size(A) - @test size(V) == (size(A, 2), size(A, 2)) - @test U' * U ≈ I - @test U * U' ≈ I - @test V * V' ≈ I - @test V' * V ≈ I - @test MatrixAlgebra.svdvals(A) ≈ diag(S) + # p = 2, relative = false + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; atol=norm([0.3, 0.2, 0.01]) + eps(real(elt))) + ) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; atol=norm([0.3, 0.2, 0.01]) - 10eps(real(elt))) + ) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; atol=(norm(diag(s)) * (one(real(elt)) + eps(real(elt))))) + ) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 + + # p = 1, relative = true + ũ, s̃, ṽ = svd_trunc( + a; + trunc=truncerr(; + rtol=(norm([0.3, 0.2, 0.01], 1) / norm(diag(s), 1) + eps(real(elt))), p=1 + ), + ) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc( + a; + trunc=truncerr(; + rtol=(norm([0.3, 0.2, 0.01], 1) / norm(diag(s), 1) - eps(real(elt))), p=1 + ), + ) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0, p=1)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=1, p=1)) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 + + # p = 1, relative = false + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; atol=(norm([0.3, 0.2, 0.01], 1) + 10eps(real(elt))), p=1) + ) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc( + a; trunc=truncerr(; atol=(norm([0.3, 0.2, 0.01], 1) - 10eps(real(elt))), p=1) + ) + @test size(ũ) == (n, 3) + @test size(s̃) == (3, 3) + @test size(ṽ) == (3, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.2, 0.01]) + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0, p=1)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a + ũ, s̃, ṽ = svd_trunc( + a; + trunc=truncerr(; atol=(norm(diag(s), 1) * (one(real(elt)) + 10eps(real(elt)))), p=1), + ) + @test size(ũ) == (n, 0) + @test size(s̃) == (0, 0) + @test size(ṽ) == (0, n) + @test norm(ũ * s̃ * ṽ) ≈ 0 + + # Specifying both `atol` and `rtol`. + s = Diagonal(real(elt)[0.1, 0.01, 0.001]) + n = length(diag(s)) + rng = StableRNG(123) + u, _ = qr_compact(randn(rng, elt, n, n); positive=true) + v, _ = qr_compact(randn(rng, elt, n, n); positive=true) + a = u * s * v + + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; rtol=0.002)) + @test size(ũ) == (n, n) + @test size(s̃) == (n, n) + @test size(ṽ) == (n, n) + @test ũ * s̃ * ṽ ≈ a + @test ũ * s̃ * ṽ ≈ a rtol = 0.002 + + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0.002)) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) + @test ũ * s̃ * ṽ ≈ a atol = 0.002 + + ũ, s̃, ṽ = svd_trunc(a; trunc=truncerr(; atol=0.002, rtol=0.002)) + @test size(ũ) == (n, 2) + @test size(s̃) == (2, 2) + @test size(ṽ) == (2, n) + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001]) + @test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002 + end end