Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end
TensorKit.jl is a package that provides types and methods to represent and manipulate
tensors with symmetries. The emphasis is on the structure and functionality needed to build
tensor network algorithms for the simulation of quantum many-body systems. Such tensors are
typically invariant under a symmetry group which acts via specific representions on each of
typically invariant under a symmetry group which acts via specific representations on each of
the indices of the tensor. TensorKit.jl provides the functionality for constructing such
tensors and performing typical operations such as tensor contractions and decompositions,
thereby preserving the symmetries and exploiting them for optimal performance.
Expand Down
22 changes: 22 additions & 0 deletions src/tensors/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ function DiagonalTensorMap(data::DenseVector{T}, V::IndexSpace) where {T}
return DiagonalTensorMap{T}(data, V)
end

function DiagonalTensorMap(t::AbstractTensorMap{T,S,1,1}) where {T,S}
isa(t, DiagonalTensorMap) && return t
domain(t) == codomain(t) ||
throw(SpaceMismatch("DiagonalTensorMap requires equal domain and codomain"))
A = storagetype(t)
d = DiagonalTensorMap{T,S,A}(undef, space(t, 1))
for (c, b) in blocks(d)
bt = block(t, c)
# TODO: rewrite in terms of `diagview` from MatrixAlgebraKit.jl
copy!(b.diag, view(bt, LinearAlgebra.diagind(bt)))
end
return t
end

# TODO: more constructors needed?

# Special case adjoint:
Expand Down Expand Up @@ -80,6 +94,14 @@ end
function Base.convert(D::Type{<:DiagonalTensorMap}, d::DiagonalTensorMap)
return DiagonalTensorMap(convert(storagetype(D), d.data), d.domain)
end
function Base.convert(::Type{DiagonalTensorMap}, t::AbstractTensorMap)
all(LinearAlgebra.isdiag ∘ last, blocks(t)) ||
throw(ArgumentError("DiagonalTensorMap requires input tensor that is diagonal"))
return DiagonalTensorMap(t)
end
function Base.convert(::Type{DiagonalTensorMap}, d::Dict{Symbol,Any})
return convert(DiagonalTensorMap, convert(TensorMap, d))
end

# Complex, real and imaginary parts
#-----------------------------------
Expand Down
5 changes: 5 additions & 0 deletions test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,14 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
@timedtestset "Tensor conversion" begin
t = @constinferred DiagonalTensorMap(undef, V)
rand!(t.data)
# element type conversion
tc = complex(t)
@test convert(typeof(tc), t) == tc
@test typeof(convert(typeof(tc), t)) == typeof(tc)
# to and from generic TensorMap
td = DiagonalTensorMap(TensorMap(t))
@test t == td
@test typeof(td) == typeof(t)
end
I = sectortype(V)
if BraidingStyle(I) isa SymmetricBraiding
Expand Down
Loading