@@ -14,14 +14,14 @@ struct DiagonalTensorMap{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap
1414 function DiagonalTensorMap {T,S,A} (data:: A ,
1515 dom:: S ) where {T,S<: IndexSpace ,A<: DenseVector{T} }
1616 T ⊆ field (S) || @warn (" scalartype(data) = $T ⊈ $(field (S)) )" , maxlog = 1 )
17- return DiagonalTensorMap {T,S,A} (data, dom)
17+ return new {T,S,A} (data, dom)
1818 end
1919end
2020reduceddim (V:: IndexSpace ) = sum (c -> dim (V, c), sectors (V); init= 0 )
2121
2222# Basic methods for characterising a tensor:
2323# --------------------------------------------
24- space (t :: DiagonalTensorMap ) = t . domain ← t . domain
24+ space (d :: DiagonalTensorMap ) = d . domain ← d . domain
2525
2626"""
2727 storagetype(::Union{T,Type{T}}) where {T<:TensorMap} -> Type{A<:DenseVector}
5858
5959# Special case adjoint:
6060# -----------------------
61- Base. adjoint (t :: DiagonalTensorMap{<:Real} ) = t
62- Base. adjoint (t :: DiagonalTensorMap{<:Complex} ) = DiagonalTensorMap (conj (t . data), t . domain)
61+ Base. adjoint (d :: DiagonalTensorMap{<:Real} ) = d
62+ Base. adjoint (d :: DiagonalTensorMap{<:Complex} ) = DiagonalTensorMap (conj (d . data), d . domain)
6363
64- # Efficient copy constructors
65- # -----------------------------
66- Base. copy (t :: DiagonalTensorMap ) = typeof (t )(copy (t . data), t . domain)
64+ # Efficient copy constructors and TensorMap converters
65+ # -----------------------------------------------------
66+ Base. copy (d :: DiagonalTensorMap ) = typeof (d )(copy (d . data), d . domain)
6767
68- function Base. complex (t:: DiagonalTensorMap )
69- if scalartype (t) <: Complex
70- return t
71- else
72- return DiagonalTensorMap (complex (t. data), t. domain)
68+ function Base. copy! (t:: AbstractTensorMap , d:: DiagonalTensorMap )
69+ space (t) == space (d) || throw (SpaceMismatch ())
70+ for (c, b) in blocks (d)
71+ copy! (block (t, c), b)
72+ end
73+ return t
74+ end
75+ TensorMap (d:: DiagonalTensorMap ) = copy! (similar (d), d)
76+ Base. convert (:: Type{TensorMap} , d:: DiagonalTensorMap ) = TensorMap (d)
77+
78+ # Complex, real and imaginary parts
79+ # -----------------------------------
80+ for f in (:real , :imag , :complex )
81+ @eval begin
82+ function Base. $f (d:: DiagonalTensorMap )
83+ return DiagonalTensorMap ($ f (d. data), d. domain)
84+ end
7385 end
7486end
7587
7688# Getting and setting the data at the block level
7789# -------------------------------------------------
78- blocksectors (t :: DiagonalTensorMap ) = blocksectors (t . domain)
90+ blocksectors (d :: DiagonalTensorMap ) = blocksectors (d . domain)
7991
80- function block (t :: DiagonalTensorMap , s:: Sector )
81- sectortype (t ) == typeof (s) || throw (SectorMismatch ())
92+ function block (d :: DiagonalTensorMap , s:: Sector )
93+ sectortype (d ) == typeof (s) || throw (SectorMismatch ())
8294 offset = 0
83- for c in sectors (t)
95+ dom = domain (d)[1 ]
96+ for c in sectors (dom)
8497 if c < s
85- offset += dim (t , c)
98+ offset += dim (dom , c)
8699 elseif c == s
87- r = offset .+ (1 : dim (t , c))
88- return Diagonal (view (t . data, r))
100+ r = offset .+ (1 : dim (dom , c))
101+ return Diagonal (view (d . data, r))
89102 else # s not in sectors(t)
90- return Diagonal (view (t . data, 1 : 0 ))
103+ return Diagonal (view (d . data, 1 : 0 ))
91104 end
92105 end
93106end
@@ -96,18 +109,130 @@ end
96109
97110# Indexing and getting and setting the data at the subblock level
98111# -----------------------------------------------------------------
99- @inline function Base. getindex (t :: DiagonalTensorMap ,
112+ @inline function Base. getindex (d :: DiagonalTensorMap ,
100113 f₁:: FusionTree{I,1} ,
101114 f₂:: FusionTree{I,1} ) where {I<: Sector }
102115 s = f₁. uncoupled[1 ]
103- s == f₁. uncoulped == f₂. uncoupled[1 ] == f₂. uncoupled || throw (SectorMismatch ())
104- return block (t , s)
116+ s == f₁. coupled == f₂. uncoupled[1 ] == f₂. coupled || throw (SectorMismatch ())
117+ return block (d , s)
105118 # TODO : do we want a StridedView here? Then we need to allocate a new matrix.
106119end
107120
108- function Base. setindex! (t :: TensorMap ,
121+ function Base. setindex! (d :: DiagonalTensorMap ,
109122 v,
110123 f₁:: FusionTree{I,1} ,
111124 f₂:: FusionTree{I,1} ) where {I<: Sector }
112- return copy! (getindex (t, f₁, f₂), v)
125+ return copy! (getindex (d, f₁, f₂), v)
126+ end
127+
128+ function Base. getindex (d:: DiagonalTensorMap )
129+ sectortype (d) === Trivial || throw (SectorMismatch ())
130+ return Diagonal (d. data)
131+ end
132+
133+ # Index manipulations
134+ # -------------------
135+ function has_shared_permute (d:: DiagonalTensorMap , (p₁, p₂):: Index2Tuple )
136+ if p₁ === codomainind (d) && p₂ === domainind (d)
137+ return true
138+ elseif BraidingStyle (sectortype (d)) isa Bosonic
139+ # TODO : is this always correct? transpose has no effect for bosonic sectors?
140+ return p₁ === domainind (d) && p₂ === codomainind (d)
141+ else
142+ return false
143+ end
144+ end
145+
146+ function permute (d:: DiagonalTensorMap , (p₁, p₂):: Index2Tuple{N₁,N₂} ;
147+ copy:: Bool = false ) where {N₁,N₂}
148+ if ! copy
149+ if p₁ === codomainind (d) && p₂ === domainind (d)
150+ return d
151+ elseif has_shared_permute (d, (p₁, p₂)) # tranpose for bosonic sectors
152+ return DiagonalTensorMap (d. data, dual (d. domain))
153+ end
154+ end
155+ # general case
156+ space′ = permute (space (d), (p₁, p₂))
157+ @inbounds begin
158+ return permute! (similar (d, space′), d, (p₁, p₂))
159+ end
160+ end
161+
162+ # VectorInterface
163+ # ---------------
164+ function VectorInterface. zerovector (d:: DiagonalTensorMap , :: Type{S} ) where {S<: Number }
165+ return DiagonalTensorMap (zerovector (d. data, S), d. domain)
113166end
167+ function VectorInterface. add (ty:: DiagonalTensorMap , tx:: DiagonalTensorMap ,
168+ α:: Number , β:: Number )
169+ domain (ty) == domain (tx) || throw (SpaceMismatch (" $(space (ty)) ≠ $(space (tx)) " ))
170+ T = VectorInterface. promote_add (ty, tx, α, β)
171+ return add! (scale! (zerovector (ty, T), ty, β), tx, α) # zerovector instead of similar preserves diagonal structure
172+ end
173+
174+ # Linear Algebra and factorizations
175+ # ---------------------------------
176+ function one! (d:: DiagonalTensorMap )
177+ fill! (d. data, one (eltype (d. data)))
178+ return d
179+ end
180+ function Base. one (d:: DiagonalTensorMap )
181+ return DiagonalTensorMap (one .(d. data), d. domain)
182+ end
183+ function Base. zero (d:: DiagonalTensorMap )
184+ return DiagonalTensorMap (zero .(d. data), d. domain)
185+ end
186+
187+ function eig! (d:: DiagonalTensorMap )
188+ return d, one (d)
189+ end
190+ function eigh! (d:: DiagonalTensorMap{<:Real} )
191+ return d, one (d)
192+ end
193+ function eigh! (d:: DiagonalTensorMap{<:Complex} )
194+ # TODO : should this test for hermiticity? `eigh!(::TensorMap)` also does not do this.
195+ return DiagonalTensorMap (real (d. data), d. domain), one (d)
196+ end
197+
198+ function leftorth! (d:: DiagonalTensorMap ; kwargs... )
199+ return one (d), d # TODO : this is only correct for `alg = QR()` or `alg = QL()`
200+ end
201+ function rightorth! (d:: DiagonalTensorMap ; kwargs... )
202+ return d, one (d) # TODO : this is only correct for `alg = LQ()` or `alg = RQ()`
203+ end
204+ # not much to do here:
205+ leftnull! (d:: DiagonalTensorMap ; kwargs... ) = leftnull! (TensorMap (d); kwargs... )
206+ rightnull! (d:: DiagonalTensorMap ; kwargs... ) = rightnull! (TensorMap (d); kwargs... )
207+
208+ function tsvd! (d:: DiagonalTensorMap ; trunc= NoTruncation (), p:: Real = 2 , alg= SDD ())
209+ return _tsvd! (d, alg, trunc, p)
210+ end
211+ # helper function
212+ function _compute_svddata! (d:: DiagonalTensorMap , alg:: Union{SVD,SDD} )
213+ InnerProductStyle (d) === EuclideanInnerProduct () || throw_invalid_innerproduct (:tsvd! )
214+ I = sectortype (d)
215+ dims = SectorDict {I,Int} ()
216+ generator = Base. Iterators. map (blocks (d)) do (c, b)
217+ lb = length (b. diag)
218+ U = zerovector! (similar (b. diag, lb, lb))
219+ V = zerovector! (similar (b. diag, lb, lb))
220+ p = sortperm (b. diag; by= abs, rev= true )
221+ for (i, pi ) in enumerate (p)
222+ U[pi , i] = MatrixAlgebra. safesign (b. diag[pi ])
223+ V[i, pi ] = 1
224+ end
225+ Σ = abs .(view (b. diag, p))
226+ dims[c] = lb
227+ return c => (U, Σ, V)
228+ end
229+ SVDdata = SectorDict (generator)
230+ return SVDdata, dims
231+ end
232+
233+ # matrix functions
234+ for f in
235+ (:exp , :cos , :sin , :tan , :cot , :cosh , :sinh , :tanh , :coth , :atan , :acot , :asinh , :sqrt ,
236+ :log , :asin , :acos , :acosh , :atanh , :acoth )
237+ @eval Base.$ f (d:: DiagonalTensorMap ) = DiagonalTensorMap ($ f .(d. data), d. domain)
238+ end
0 commit comments