diff --git a/README.md b/README.md index baf9a8e1f..835982c6a 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ end TensorKit.jl is a package that provides types and methods to represent and manipulate tensors with symmetries. The emphasis is on the structure and functionality needed to build tensor network algorithms for the simulation of quantum many-body systems. Such tensors are -typically invariant under a symmetry group which acts via specific representions on each of +typically invariant under a symmetry group which acts via specific representations on each of the indices of the tensor. TensorKit.jl provides the functionality for constructing such tensors and performing typical operations such as tensor contractions and decompositions, thereby preserving the symmetries and exploiting them for optimal performance. diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index cce5624ea..a13918d03 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -52,6 +52,20 @@ function DiagonalTensorMap(data::DenseVector{T}, V::IndexSpace) where {T} return DiagonalTensorMap{T}(data, V) end +function DiagonalTensorMap(t::AbstractTensorMap{T,S,1,1}) where {T,S} + isa(t, DiagonalTensorMap) && return t + domain(t) == codomain(t) || + throw(SpaceMismatch("DiagonalTensorMap requires equal domain and codomain")) + A = storagetype(t) + d = DiagonalTensorMap{T,S,A}(undef, space(t, 1)) + for (c, b) in blocks(d) + bt = block(t, c) + # TODO: rewrite in terms of `diagview` from MatrixAlgebraKit.jl + copy!(b.diag, view(bt, LinearAlgebra.diagind(bt))) + end + return d +end + # TODO: more constructors needed? # Special case adjoint: @@ -73,12 +87,17 @@ end TensorMap(d::DiagonalTensorMap) = copy!(similar(d), d) Base.convert(::Type{TensorMap}, d::DiagonalTensorMap) = TensorMap(d) -function Base.convert(::Type{DiagonalTensorMap{T,S,A}}, - d::DiagonalTensorMap{T,S,A}) where {T,S,A} - return d -end function Base.convert(D::Type{<:DiagonalTensorMap}, d::DiagonalTensorMap) - return DiagonalTensorMap(convert(storagetype(D), d.data), d.domain) + return (d isa D) ? d : DiagonalTensorMap(convert(storagetype(D), d.data), d.domain) +end +Base.convert(::Type{DiagonalTensorMap}, t::DiagonalTensorMap) = t +function Base.convert(::Type{DiagonalTensorMap}, t::AbstractTensorMap) + LinearAlgebra.isdiag(t) || + throw(ArgumentError("DiagonalTensorMap requires input tensor that is diagonal")) + return DiagonalTensorMap(t) +end +function Base.convert(::Type{DiagonalTensorMap}, d::Dict{Symbol,Any}) + return convert(DiagonalTensorMap, convert(TensorMap, d)) end # Complex, real and imaginary parts diff --git a/test/diagonal.jl b/test/diagonal.jl index 83cdee10e..e4b3f6225 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -89,9 +89,14 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3), @timedtestset "Tensor conversion" begin t = @constinferred DiagonalTensorMap(undef, V) rand!(t.data) + # element type conversion tc = complex(t) @test convert(typeof(tc), t) == tc @test typeof(convert(typeof(tc), t)) == typeof(tc) + # to and from generic TensorMap + td = DiagonalTensorMap(TensorMap(t)) + @test t == td + @test typeof(td) == typeof(t) end I = sectortype(V) if BraidingStyle(I) isa SymmetricBraiding