Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/TensorKitChainRulesCoreExt/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
Sp = view(S, 1:p)

# rank
r = findlast(>=(tol), S)
r = count(>(tol), S)

# compute antihermitian part of projection of ΔU and ΔV onto U and V
# also already subtract this projection from ΔU and ΔV
Expand Down
4 changes: 2 additions & 2 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ export mul!, lmul!, rmul!, adjoint!, pinv, axpy!, axpby!
export leftorth, rightorth, leftnull, rightnull,
leftorth!, rightorth!, leftnull!, rightnull!,
tsvd!, tsvd, eigen, eigen!, eig, eig!, eigh, eigh!, exp, exp!,
isposdef, isposdef!, ishermitian, sylvester
isposdef, isposdef!, ishermitian, sylvester, rank, cond
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
repartition!
export catdomain, catcodomain
Expand Down Expand Up @@ -119,7 +119,7 @@ using LinearAlgebra: norm, dot, normalize, normalize!, tr,
adjoint, adjoint!, transpose, transpose!,
lu, pinv, sylvester,
eigen, eigen!, svd, svd!,
isposdef, isposdef!, ishermitian,
isposdef, isposdef!, ishermitian, rank, cond,
Diagonal, Hermitian

using SparseArrays: SparseMatrixCSC, sparse, nzrange, rowvals, nonzeros
Expand Down
26 changes: 26 additions & 0 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,32 @@
end
end

_default_rtol(t) = eps(real(float(scalartype(t)))) * min(dim(domain(t)), dim(codomain(t)))

function LinearAlgebra.rank(t::AbstractTensorMap; atol::Real=0,
rtol::Real=atol > 0 ? 0 : _default_rtol(t))
dim(t) == 0 && return 0
S = LinearAlgebra.svdvals(t)
tol = max(atol, rtol * maximum(first, values(S)))
return sum(cs -> dim(cs[1]) * count(>(tol), cs[2]), S)
end

function LinearAlgebra.cond(t::AbstractTensorMap, p::Real=2)
if p == 2
if dim(t) == 0
domain(t) == codomain(t) ||
throw(SpaceMismatch("`cond` requires domain and codomain to be the same"))
return zero(real(float(scalartype(t))))
end
S = LinearAlgebra.svdvals(t)
maxS = maximum(first, values(S))
minS = minimum(last, values(S))
return iszero(maxS) ? oftype(maxS, Inf) : (maxS / minS)
else
throw(ArgumentError("cond currently only defined for p=2"))

Check warning on line 296 in src/tensors/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/linalg.jl#L296

Added line #L296 was not covered by tests
end
end

# TensorMap trace
function LinearAlgebra.tr(t::AbstractTensorMap)
domain(t) == codomain(t) ||
Expand Down
25 changes: 25 additions & 0 deletions test/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,23 @@ for V in spacelist
@test b ≈ s′[c]
end
end
@testset "cond and rank" begin
t2 = permute(t, ((3, 4, 2), (1, 5)))
d1 = dim(codomain(t2))
d2 = dim(domain(t2))
@test rank(t2) == min(d1, d2)
M = leftnull(t2)
@test rank(M) == max(d1, d2) - min(d1, d2)
t3 = unitary(T, V1 ⊗ V2, V1 ⊗ V2)
@test cond(t3) ≈ one(real(T))
@test rank(t3) == dim(V1 ⊗ V2)
t4 = randn(T, V1 ⊗ V2, V1 ⊗ V2)
t4 = (t4 + t4') / 2
vals = LinearAlgebra.eigvals(t4)
λmax = maximum(s -> maximum(abs, s), values(vals))
λmin = minimum(s -> minimum(abs, s), values(vals))
@test cond(t4) ≈ λmax / λmin
end
end
@testset "empty tensor" begin
t = randn(T, V1 ⊗ V2, zero(V1))
Expand Down Expand Up @@ -586,6 +603,13 @@ for V in spacelist
@test U == t
@test dim(U) == dim(S) == dim(V)
end
@testset "cond and rank" begin
@test rank(t) == 0
W2 = zero(V1) * zero(V2)
t2 = rand(W2, W2)
@test rank(t2) == 0
@test cond(t2) == 0.0
end
end
t = rand(T, V1 ⊗ V1' ⊗ V2 ⊗ V2')
@testset "eig and isposdef" begin
Expand Down Expand Up @@ -615,6 +639,7 @@ for V in spacelist
@test V ≈ Ṽ
λ = minimum(minimum(real(LinearAlgebra.diag(b)))
for (c, b) in blocks(D))
@test cond(Ṽ) ≈ one(real(T))
@test isposdef(t2) == isposdef(λ)
@test isposdef(t2 - λ * one(t2) + 0.1 * one(t2))
@test !isposdef(t2 - λ * one(t2) - 0.1 * one(t2))
Expand Down
Loading