Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ storagetype(::Type{<:DiagonalTensorMap{T,S,A}}) where {T,S,A<:DenseVector{T}} =

Construct a `DiagonalTensorMap` with uninitialized data.
"""
function DiagonalTensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T}
(numin(V) == numout(V) == 1 && domain(V) == codomain(V)) ||
throw(ArgumentError("DiagonalTensorMap requires a space with equal domain and codomain and 2 indices"))
return DiagonalTensorMap{T}(undef, domain(V))
end
function DiagonalTensorMap{T}(::UndefInitializer, V::ProductSpace) where {T}
length(V) == 1 ||
throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`"))
return DiagonalTensorMap{T}(undef, only(V))
end
function DiagonalTensorMap{T}(::UndefInitializer, V::S) where {T,S<:IndexSpace}
return DiagonalTensorMap{T,S,Vector{T}}(undef, V)
end
Expand Down Expand Up @@ -265,6 +275,22 @@ function LinearAlgebra.mul!(dC::DiagonalTensorMap,
return dC
end

function LinearAlgebra.lmul!(D::DiagonalTensorMap, t::AbstractTensorMap)
domain(D) == codomain(t) || throw(SpaceMismatch())
for (c, b) in blocks(t)
lmul!(block(D, c), b)
end
return t
end

function LinearAlgebra.rmul!(t::AbstractTensorMap, D::DiagonalTensorMap)
codomain(D) == domain(t) || throw(SpaceMismatch())
for (c, b) in blocks(t)
rmul!(b, block(D, c))
end
return t
end

Base.inv(d::DiagonalTensorMap) = DiagonalTensorMap(inv.(d.data), d.domain)
function Base.:\(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
d1.domain == d2.domain || throw(SpaceMismatch())
Expand Down Expand Up @@ -339,6 +365,17 @@ function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD})
return SVDdata, dims
end

function LinearAlgebra.svdvals(d::DiagonalTensorMap)
return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d))
end
function LinearAlgebra.eigvals(d::DiagonalTensorMap)
return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d))
end

function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2)
return LinearAlgebra.cond(Diagonal(d.data), p)
end

# matrix functions
for f in
(:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt,
Expand Down
30 changes: 30 additions & 0 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
@testset "DiagonalTensor with domain $V" for V in diagspacelist
@timedtestset "Basic properties and algebra" begin
for T in (Float32, Float64, ComplexF32, ComplexF64, BigFloat)
# constructors
t = @constinferred DiagonalTensorMap{T}(undef, V)
t = @constinferred DiagonalTensorMap(rand(T, reduceddim(V)), V)
t2 = @constinferred DiagonalTensorMap{T}(undef, space(t))
@test space(t2) == space(t)
@test_throws ArgumentError DiagonalTensorMap{T}(undef, V^2 V)
t2 = @constinferred DiagonalTensorMap{T}(undef, domain(t))
@test space(t2) == space(t)
@test_throws ArgumentError DiagonalTensorMap{T}(undef, V^2)
# properties
@test @constinferred(hash(t)) == hash(deepcopy(t))
@test scalartype(t) == T
@test codomain(t) == ProductSpace(V)
Expand Down Expand Up @@ -135,6 +143,16 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
@test u / t1 u / TensorMap(t1)
@test t1 * u' TensorMap(t1) * u'
@test t1 \ u' TensorMap(t1) \ u'

t3 = rand(Float64, V V^2)
t4 = rand(ComplexF64, V V^2)
@test t1 * t3 lmul!(t1, copy(t3))
@test t2 * t4 lmul!(t2, copy(t4))

t3 = rand(Float64, V^2 V)
t4 = rand(ComplexF64, V^2 V)
@test t3 * t1 rmul!(copy(t3), t1)
@test t4 * t2 rmul!(copy(t4), t2)
end
@timedtestset "Tensor contraction" begin
d = DiagonalTensorMap(rand(ComplexF64, reduceddim(V)), V)
Expand Down Expand Up @@ -175,6 +193,12 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
VdV2 = V2' * V2
@test VdV2 one(VdV2)
@test t2 * V2 V2 * D2

@test rank(D) rank(t)
@test cond(D) cond(t)
@test all(((s, t),) -> isapprox(s, t),
zip(values(LinearAlgebra.eigvals(D)),
values(LinearAlgebra.eigvals(t))))
end
@testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QL())
Q, R = @constinferred leftorth(t; alg=alg)
Expand All @@ -201,6 +225,12 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
VdV = Vᴴ * Vᴴ'
@test VdV one(VdV)
@test U * S * Vᴴ t

@test rank(S) rank(t)
@test cond(S) cond(t)
@test all(((s, t),) -> isapprox(s, t),
zip(values(LinearAlgebra.svdvals(S)),
values(LinearAlgebra.svdvals(t))))
end
end
end
Expand Down
Loading