Skip to content

Commit 58cd26e

Browse files
committed
start implementation of DiagonalTensorMap
1 parent f218222 commit 58cd26e

File tree

3 files changed

+117
-2
lines changed

3 files changed

+117
-2
lines changed

src/TensorKit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ export Vect, Rep # space constructors
2525
export CompositeSpace, ProductSpace # composite spaces
2626
export FusionTree
2727
export IndexSpace, HomSpace, TensorSpace, TensorMapSpace
28-
export AbstractTensorMap, AbstractTensor, TensorMap, Tensor, BraidingTensor # tensors and tensor properties
28+
export AbstractTensorMap, AbstractTensor, TensorMap, Tensor # tensors and tensor properties
29+
export DiagonalTensorMap, BraidingTensor
2930
export TruncationScheme
3031
export SpaceMismatch, SectorMismatch, IndexError # error types
3132

@@ -183,6 +184,7 @@ include("spaces/vectorspaces.jl")
183184
include("tensors/abstracttensor.jl")
184185
# include("tensors/tensortreeiterator.jl")
185186
include("tensors/tensor.jl")
187+
include("tensors/diagtensor.jl")
186188
include("tensors/adjoint.jl")
187189
include("tensors/linalg.jl")
188190
include("tensors/vectorinterface.jl")

src/spaces/vectorspaces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ field(P::Type{<:CompositeSpace}) = field(spacetype(P))
256256
sectortype(P::Type{<:CompositeSpace}) = sectortype(spacetype(P))
257257

258258
# make ElementarySpace instances behave similar to ProductSpace instances
259-
blocksectors(V::ElementarySpace) = sectors(V)
259+
blocksectors(V::ElementarySpace) = collect(sectors(V))
260260
blockdim(V::ElementarySpace, c::Sector) = dim(V, c)
261261

262262
# Specific realizations of ElementarySpace types

src/tensors/diagtensor.jl

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# DiagonalTensorMap
2+
#==========================================================#
3+
struct DiagonalTensorMap{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T,S,1,1}
4+
data::A
5+
domain::S # equals codomain
6+
7+
# uninitialized constructors
8+
function DiagonalTensorMap{T,S,A}(::UndefInitializer,
9+
dom::S) where {T,S<:IndexSpace,A<:DenseVector{T}}
10+
data = A(undef, reduceddim(dom))
11+
return DiagonalTensorMap{T,S,A}(data, dom)
12+
end
13+
# constructors from data
14+
function DiagonalTensorMap{T,S,A}(data::A,
15+
dom::S) where {T,S<:IndexSpace,A<:DenseVector{T}}
16+
T field(S) || @warn("scalartype(data) = $T ⊈ $(field(S)))", maxlog = 1)
17+
return DiagonalTensorMap{T,S,A}(data, dom)
18+
end
19+
end
20+
reduceddim(V::IndexSpace) = sum(c -> dim(V, c), sectors(V); init=0)
21+
22+
# Basic methods for characterising a tensor:
23+
#--------------------------------------------
24+
space(t::DiagonalTensorMap) = t.domain t.domain
25+
26+
"""
27+
storagetype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:DenseVector}
28+
29+
Return the type of the storage `A` of the tensor map.
30+
"""
31+
storagetype(::Type{<:DiagonalTensorMap{T,S,A}}) where {T,S,A<:DenseVector{T}} = A
32+
33+
# DiagonalTensorMap constructors
34+
#--------------------------------
35+
# undef constructors
36+
"""
37+
DiagonalTensorMap{T}(undef, domain::S) where {T,S<:IndexSpace}
38+
# expert mode: select storage type `A`
39+
DiagonalTensorMap{T,S,A}(undef, domain::S) where {T,S<:IndexSpace,A<:DenseVector{T}}
40+
41+
Construct a `DiagonalTensorMap` with uninitialized data.
42+
"""
43+
function DiagonalTensorMap{T}(::UndefInitializer, V::S) where {T,S<:IndexSpace}
44+
return DiagonalTensorMap{T,S,Vector{T}}(undef, V)
45+
end
46+
47+
function DiagonalTensorMap{T}(data::A, V::S) where {T,S<:IndexSpace,A<:DenseVector{T}}
48+
length(data) == reduceddim(V) ||
49+
throw(DimensionMismatch("length(data) = $(length(data)) is not compatible with the space $V"))
50+
return DiagonalTensorMap{T,S,A}(data, V)
51+
end
52+
53+
function DiagonalTensorMap(data::DenseVector{T}, V::IndexSpace) where {T}
54+
return DiagonalTensorMap{T}(data, V)
55+
end
56+
57+
# TODO: more constructors needed?
58+
59+
# Special case adjoint:
60+
#-----------------------
61+
Base.adjoint(t::DiagonalTensorMap{<:Real}) = t
62+
Base.adjoint(t::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(conj(t.data), t.domain)
63+
64+
# Efficient copy constructors
65+
#-----------------------------
66+
Base.copy(t::DiagonalTensorMap) = typeof(t)(copy(t.data), t.domain)
67+
68+
function Base.complex(t::DiagonalTensorMap)
69+
if scalartype(t) <: Complex
70+
return t
71+
else
72+
return DiagonalTensorMap(complex(t.data), t.domain)
73+
end
74+
end
75+
76+
# Getting and setting the data at the block level
77+
#-------------------------------------------------
78+
blocksectors(t::DiagonalTensorMap) = blocksectors(t.domain)
79+
80+
function block(t::DiagonalTensorMap, s::Sector)
81+
sectortype(t) == typeof(s) || throw(SectorMismatch())
82+
offset = 0
83+
for c in sectors(t)
84+
if c < s
85+
offset += dim(t, c)
86+
elseif c == s
87+
r = offset .+ (1:dim(t, c))
88+
return Diagonal(view(t.data, r))
89+
else # s not in sectors(t)
90+
return Diagonal(view(t.data, 1:0))
91+
end
92+
end
93+
end
94+
95+
# TODO: is relying on generic AbstractTensorMap blocks sufficient?
96+
97+
# Indexing and getting and setting the data at the subblock level
98+
#-----------------------------------------------------------------
99+
@inline function Base.getindex(t::DiagonalTensorMap,
100+
f₁::FusionTree{I,1},
101+
f₂::FusionTree{I,1}) where {I<:Sector}
102+
s = f₁.uncoupled[1]
103+
s == f₁.uncoulped == f₂.uncoupled[1] == f₂.uncoupled || throw(SectorMismatch())
104+
return block(t, s)
105+
# TODO: do we want a StridedView here? Then we need to allocate a new matrix.
106+
end
107+
108+
function Base.setindex!(t::TensorMap,
109+
v,
110+
f₁::FusionTree{I,1},
111+
f₂::FusionTree{I,1}) where {I<:Sector}
112+
return copy!(getindex(t, f₁, f₂), v)
113+
end

0 commit comments

Comments
 (0)