Skip to content

Commit 1c723ea

Browse files
committed
further progress on DiagonalTensorMap [skip ci]
1 parent 41e3cbb commit 1c723ea

File tree

4 files changed

+171
-29
lines changed

4 files changed

+171
-29
lines changed

src/spaces/homspace.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ Return the `HomSpace` obtained by permuting the indices of the domain and codoma
129129
according to the permutation `p₁` and `p₂` respectively.
130130
"""
131131
function permute(W::HomSpace{S}, (p₁, p₂)::Index2Tuple{N₁,N₂}) where {S,N₁,N₂}
132+
p = (p₁..., p₂...)
133+
TupleTools.isperm(p) && length(p) == numind(W) ||
134+
throw(ArgumentError("$((p₁, p₂)) is not a valid permutation for $(W)"))
132135
cod = ProductSpace{S,N₁}(map(n -> W[n], p₁))
133136
dom = ProductSpace{S,N₂}(map(n -> dual(W[n]), p₂))
134137
return cod dom

src/tensors/diagtensor.jl

Lines changed: 150 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ struct DiagonalTensorMap{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap
1414
function DiagonalTensorMap{T,S,A}(data::A,
1515
dom::S) where {T,S<:IndexSpace,A<:DenseVector{T}}
1616
T field(S) || @warn("scalartype(data) = $T ⊈ $(field(S)))", maxlog = 1)
17-
return DiagonalTensorMap{T,S,A}(data, dom)
17+
return new{T,S,A}(data, dom)
1818
end
1919
end
2020
reduceddim(V::IndexSpace) = sum(c -> dim(V, c), sectors(V); init=0)
2121

2222
# Basic methods for characterising a tensor:
2323
#--------------------------------------------
24-
space(t::DiagonalTensorMap) = t.domain t.domain
24+
space(d::DiagonalTensorMap) = d.domain d.domain
2525

2626
"""
2727
storagetype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:DenseVector}
@@ -58,36 +58,49 @@ end
5858

5959
# Special case adjoint:
6060
#-----------------------
61-
Base.adjoint(t::DiagonalTensorMap{<:Real}) = t
62-
Base.adjoint(t::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(conj(t.data), t.domain)
61+
Base.adjoint(d::DiagonalTensorMap{<:Real}) = d
62+
Base.adjoint(d::DiagonalTensorMap{<:Complex}) = DiagonalTensorMap(conj(d.data), d.domain)
6363

64-
# Efficient copy constructors
65-
#-----------------------------
66-
Base.copy(t::DiagonalTensorMap) = typeof(t)(copy(t.data), t.domain)
64+
# Efficient copy constructors and TensorMap converters
65+
#-----------------------------------------------------
66+
Base.copy(d::DiagonalTensorMap) = typeof(d)(copy(d.data), d.domain)
6767

68-
function Base.complex(t::DiagonalTensorMap)
69-
if scalartype(t) <: Complex
70-
return t
71-
else
72-
return DiagonalTensorMap(complex(t.data), t.domain)
68+
function Base.copy!(t::AbstractTensorMap, d::DiagonalTensorMap)
69+
space(t) == space(d) || throw(SpaceMismatch())
70+
for (c, b) in blocks(d)
71+
copy!(block(t, c), b)
72+
end
73+
return t
74+
end
75+
TensorMap(d::DiagonalTensorMap) = copy!(similar(d), d)
76+
Base.convert(::Type{TensorMap}, d::DiagonalTensorMap) = TensorMap(d)
77+
78+
# Complex, real and imaginary parts
79+
#-----------------------------------
80+
for f in (:real, :imag, :complex)
81+
@eval begin
82+
function Base.$f(d::DiagonalTensorMap)
83+
return DiagonalTensorMap($f(d.data), d.domain)
84+
end
7385
end
7486
end
7587

7688
# Getting and setting the data at the block level
7789
#-------------------------------------------------
78-
blocksectors(t::DiagonalTensorMap) = blocksectors(t.domain)
90+
blocksectors(d::DiagonalTensorMap) = blocksectors(d.domain)
7991

80-
function block(t::DiagonalTensorMap, s::Sector)
81-
sectortype(t) == typeof(s) || throw(SectorMismatch())
92+
function block(d::DiagonalTensorMap, s::Sector)
93+
sectortype(d) == typeof(s) || throw(SectorMismatch())
8294
offset = 0
83-
for c in sectors(t)
95+
dom = domain(d)[1]
96+
for c in sectors(dom)
8497
if c < s
85-
offset += dim(t, c)
98+
offset += dim(dom, c)
8699
elseif c == s
87-
r = offset .+ (1:dim(t, c))
88-
return Diagonal(view(t.data, r))
100+
r = offset .+ (1:dim(dom, c))
101+
return Diagonal(view(d.data, r))
89102
else # s not in sectors(t)
90-
return Diagonal(view(t.data, 1:0))
103+
return Diagonal(view(d.data, 1:0))
91104
end
92105
end
93106
end
@@ -96,18 +109,130 @@ end
96109

97110
# Indexing and getting and setting the data at the subblock level
98111
#-----------------------------------------------------------------
99-
@inline function Base.getindex(t::DiagonalTensorMap,
112+
@inline function Base.getindex(d::DiagonalTensorMap,
100113
f₁::FusionTree{I,1},
101114
f₂::FusionTree{I,1}) where {I<:Sector}
102115
s = f₁.uncoupled[1]
103-
s == f₁.uncoulped == f₂.uncoupled[1] == f₂.uncoupled || throw(SectorMismatch())
104-
return block(t, s)
116+
s == f₁.coupled == f₂.uncoupled[1] == f₂.coupled || throw(SectorMismatch())
117+
return block(d, s)
105118
# TODO: do we want a StridedView here? Then we need to allocate a new matrix.
106119
end
107120

108-
function Base.setindex!(t::TensorMap,
121+
function Base.setindex!(d::DiagonalTensorMap,
109122
v,
110123
f₁::FusionTree{I,1},
111124
f₂::FusionTree{I,1}) where {I<:Sector}
112-
return copy!(getindex(t, f₁, f₂), v)
125+
return copy!(getindex(d, f₁, f₂), v)
126+
end
127+
128+
function Base.getindex(d::DiagonalTensorMap)
129+
sectortype(d) === Trivial || throw(SectorMismatch())
130+
return Diagonal(d.data)
131+
end
132+
133+
# Index manipulations
134+
# -------------------
135+
function has_shared_permute(d::DiagonalTensorMap, (p₁, p₂)::Index2Tuple)
136+
if p₁ === codomainind(d) && p₂ === domainind(d)
137+
return true
138+
elseif BraidingStyle(sectortype(d)) isa Bosonic
139+
# TODO: is this always correct? transpose has no effect for bosonic sectors?
140+
return p₁ === domainind(d) && p₂ === codomainind(d)
141+
else
142+
return false
143+
end
144+
end
145+
146+
function permute(d::DiagonalTensorMap, (p₁, p₂)::Index2Tuple{N₁,N₂};
147+
copy::Bool=false) where {N₁,N₂}
148+
if !copy
149+
if p₁ === codomainind(d) && p₂ === domainind(d)
150+
return d
151+
elseif has_shared_permute(d, (p₁, p₂)) # tranpose for bosonic sectors
152+
return DiagonalTensorMap(d.data, dual(d.domain))
153+
end
154+
end
155+
# general case
156+
space′ = permute(space(d), (p₁, p₂))
157+
@inbounds begin
158+
return permute!(similar(d, space′), d, (p₁, p₂))
159+
end
160+
end
161+
162+
# VectorInterface
163+
# ---------------
164+
function VectorInterface.zerovector(d::DiagonalTensorMap, ::Type{S}) where {S<:Number}
165+
return DiagonalTensorMap(zerovector(d.data, S), d.domain)
113166
end
167+
function VectorInterface.add(ty::DiagonalTensorMap, tx::DiagonalTensorMap,
168+
α::Number, β::Number)
169+
domain(ty) == domain(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
170+
T = VectorInterface.promote_add(ty, tx, α, β)
171+
return add!(scale!(zerovector(ty, T), ty, β), tx, α) # zerovector instead of similar preserves diagonal structure
172+
end
173+
174+
# Linear Algebra and factorizations
175+
# ---------------------------------
176+
function one!(d::DiagonalTensorMap)
177+
fill!(d.data, one(eltype(d.data)))
178+
return d
179+
end
180+
function Base.one(d::DiagonalTensorMap)
181+
return DiagonalTensorMap(one.(d.data), d.domain)
182+
end
183+
function Base.zero(d::DiagonalTensorMap)
184+
return DiagonalTensorMap(zero.(d.data), d.domain)
185+
end
186+
187+
function eig!(d::DiagonalTensorMap)
188+
return d, one(d)
189+
end
190+
function eigh!(d::DiagonalTensorMap{<:Real})
191+
return d, one(d)
192+
end
193+
function eigh!(d::DiagonalTensorMap{<:Complex})
194+
# TODO: should this test for hermiticity? `eigh!(::TensorMap)` also does not do this.
195+
return DiagonalTensorMap(real(d.data), d.domain), one(d)
196+
end
197+
198+
function leftorth!(d::DiagonalTensorMap; kwargs...)
199+
return one(d), d # TODO: this is only correct for `alg = QR()` or `alg = QL()`
200+
end
201+
function rightorth!(d::DiagonalTensorMap; kwargs...)
202+
return d, one(d) # TODO: this is only correct for `alg = LQ()` or `alg = RQ()`
203+
end
204+
# not much to do here:
205+
leftnull!(d::DiagonalTensorMap; kwargs...) = leftnull!(TensorMap(d); kwargs...)
206+
rightnull!(d::DiagonalTensorMap; kwargs...) = rightnull!(TensorMap(d); kwargs...)
207+
208+
function tsvd!(d::DiagonalTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
209+
return _tsvd!(d, alg, trunc, p)
210+
end
211+
# helper function
212+
function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD})
213+
InnerProductStyle(d) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!)
214+
I = sectortype(d)
215+
dims = SectorDict{I,Int}()
216+
generator = Base.Iterators.map(blocks(d)) do (c, b)
217+
lb = length(b.diag)
218+
U = zerovector!(similar(b.diag, lb, lb))
219+
V = zerovector!(similar(b.diag, lb, lb))
220+
p = sortperm(b.diag; by=abs, rev=true)
221+
for (i, pi) in enumerate(p)
222+
U[pi, i] = MatrixAlgebra.safesign(b.diag[pi])
223+
V[i, pi] = 1
224+
end
225+
Σ = abs.(view(b.diag, p))
226+
dims[c] = lb
227+
return c => (U, Σ, V)
228+
end
229+
SVDdata = SectorDict(generator)
230+
return SVDdata, dims
231+
end
232+
233+
# matrix functions
234+
for f in
235+
(:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt,
236+
:log, :asin, :acos, :acosh, :atanh, :acoth)
237+
@eval Base.$f(d::DiagonalTensorMap) = DiagonalTensorMap($f.(d.data), d.domain)
238+
end

src/tensors/indexmanipulations.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,23 @@ function permute(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple{N₁,N₂};
3131
copy::Bool=false) where {N₁,N₂}
3232
space′ = permute(space(t), (p₁, p₂))
3333
# share data if possible
34+
if !copy && p₁ === codomainind(t) && p₂ === domainind(t)
35+
return t
36+
end
37+
# general case
38+
@inbounds begin
39+
return permute!(similar(t, space′), t, (p₁, p₂))
40+
end
41+
end
42+
function permute(t::TensorMap, (p₁, p₂)::Index2Tuple{N₁,N₂}; copy::Bool=false) where {N₁,N₂}
43+
@show (p₁, p₂)
44+
space′ = permute(space(t), (p₁, p₂))
45+
# share data if possible
3446
if !copy
3547
if p₁ === codomainind(t) && p₂ === domainind(t)
3648
return t
3749
elseif has_shared_permute(t, (p₁, p₂))
38-
return TensorMap(reshape(t.data, dim(codomain(space′)), dim(domain(space′))),
39-
codomain(space′), domain(space′))
50+
return TensorMap(t.data, space′)
4051
end
4152
end
4253
# general case
@@ -53,6 +64,9 @@ function permute(t::AbstractTensorMap, p::IndexTuple; copy::Bool=false)
5364
return permute(t, (p, ()); copy)
5465
end
5566

67+
function has_shared_permute(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
68+
return (p₁ === codomainind(t) && p₂ === domainind(t))
69+
end
5670
function has_shared_permute(t::TensorMap, (p₁, p₂)::Index2Tuple)
5771
if p₁ === codomainind(t) && p₂ === domainind(t)
5872
return true

src/tensors/vectorinterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ VectorInterface.zerovector!!(t::AbstractTensorMap) = zerovector!(t)
2323
#-------------------------
2424
function VectorInterface.scale(t::AbstractTensorMap, α::Number)
2525
T = VectorInterface.promote_scale(t, α)
26-
return scale!(similar(t, T), t, α)
26+
return scale!(zerovector(t, T), t, α)
2727
end
2828
function VectorInterface.scale!(t::AbstractTensorMap, α::Number)
2929
for (c, b) in blocks(t)
@@ -74,7 +74,7 @@ function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
7474
α::Number, β::Number)
7575
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
7676
for ((cy, by), (cx, bx)) in zip(blocks(ty), blocks(tx))
77-
add!(StridedView(by), StridedView(bx), α, β)
77+
add!(by, bx, α, β)
7878
end
7979
return ty
8080
end

0 commit comments

Comments
 (0)