Skip to content

Commit 4e351d3

Browse files
committed
use diagonaltensormap
1 parent c041bfe commit 4e351d3

File tree

1 file changed

+123
-70
lines changed

1 file changed

+123
-70
lines changed

src/tensors/factorizations.jl

Lines changed: 123 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@ algorithm that computes the decomposition (`_gesvd` or `_gesdd`).
3636
Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)`
3737
is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`.
3838
"""
39-
function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
40-
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
39+
function tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
40+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
41+
return tsvd!(tcopy; kwargs...)
4142
end
4243

43-
LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
44-
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
45-
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
44+
function LinearAlgebra.svdvals(t::AbstractTensorMap)
45+
tcopy = copy!(similar(t, float(scalartype(t))), t)
46+
return LinearAlgebra.svdvals!(tcopy)
4647
end
4748

4849
"""
@@ -67,8 +68,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
6768
`leftorth(!)` is currently only implemented for
6869
`InnerProductStyle(t) === EuclideanInnerProduct()`.
6970
"""
70-
function leftorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
71-
return leftorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
71+
function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
72+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
73+
return leftorth!(tcopy; kwargs...)
7274
end
7375

7476
"""
@@ -95,8 +97,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
9597
`rightorth(!)` is currently only implemented for
9698
`InnerProductStyle(t) === EuclideanInnerProduct()`.
9799
"""
98-
function rightorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
99-
return rightorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
100+
function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
101+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
102+
return rightorth!(tcopy; kwargs...)
100103
end
101104

102105
"""
@@ -121,8 +124,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
121124
`leftnull(!)` is currently only implemented for
122125
`InnerProductStyle(t) === EuclideanInnerProduct()`.
123126
"""
124-
function leftnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
125-
return leftnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
127+
function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
128+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
129+
return leftnull!(tcopy; kwargs...)
126130
end
127131

128132
"""
@@ -149,8 +153,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
149153
`rightnull(!)` is currently only implemented for
150154
`InnerProductStyle(t) === EuclideanInnerProduct()`.
151155
"""
152-
function rightnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
153-
return rightnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
156+
function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
157+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
158+
return rightnull!(tcopy; kwargs...)
154159
end
155160

156161
"""
@@ -172,17 +177,14 @@ matrices. See the corresponding documentation for more information.
172177
173178
See also `eig` and `eigh`
174179
"""
175-
function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
176-
kwargs...)
177-
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
180+
function LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
181+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
182+
return eigen!(tcopy; kwargs...)
178183
end
179184

180185
function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
181-
return LinearAlgebra.eigvals!(copy(t); kwargs...)
182-
end
183-
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
184-
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
185-
for (c, b) in blocks(t))
186+
tcopy = copy!(similar(t, float(scalartype(t))), t)
187+
return LinearAlgebra.eigvals!(tcopy; kwargs...)
186188
end
187189

188190
"""
@@ -207,8 +209,9 @@ matrices. See the corresponding documentation for more information.
207209
208210
See also `eigen` and `eigh`.
209211
"""
210-
function eig(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
211-
return eig!(permute(t, (p₁, p₂); copy=true); kwargs...)
212+
function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
213+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
214+
return eig!(tcopy; kwargs...)
212215
end
213216

214217
"""
@@ -231,8 +234,9 @@ permute(t, (leftind, rightind)) * V = V * D
231234
232235
See also `eigen` and `eig`.
233236
"""
234-
function eigh(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
235-
return eigh!(permute(t, (p₁, p₂); copy=true))
237+
function eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
238+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
239+
return eigh!(tcopy; kwargs...)
236240
end
237241

238242
"""
@@ -247,31 +251,54 @@ which `isposdef!` is called should have equal domain and codomain, as otherwise
247251
meaningless.
248252
"""
249253
function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
250-
return isposdef!(permute(t, (p₁, p₂); copy=true))
254+
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
255+
return isposdef!(tcopy)
251256
end
252257

253-
tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...)
258+
function tsvd(t::AbstractTensorMap; kwargs...)
259+
tcopy = copy!(similar(t, float(scalartype(t))), t)
260+
return tsvd!(tcopy; kwargs...)
261+
end
254262
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
255-
return leftorth!(copy(t); alg=alg, kwargs...)
263+
tcopy = copy!(similar(t, float(scalartype(t))), t)
264+
return leftorth!(tcopy; alg=alg, kwargs...)
256265
end
257266
function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...)
258-
return rightorth!(copy(t); alg=alg, kwargs...)
267+
tcopy = copy!(similar(t, float(scalartype(t))), t)
268+
return rightorth!(tcopy; alg=alg, kwargs...)
259269
end
260270
function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...)
261-
return leftnull!(copy(t); alg=alg, kwargs...)
271+
tcopy = copy!(similar(t, float(scalartype(t))), t)
272+
return leftnull!(tcopy; alg=alg, kwargs...)
262273
end
263274
function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...)
264-
return rightnull!(copy(t); alg=alg, kwargs...)
275+
tcopy = copy!(similar(t, float(scalartype(t))), t)
276+
return rightnull!(tcopy; alg=alg, kwargs...)
277+
end
278+
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
279+
tcopy = copy!(similar(t, float(scalartype(t))), t)
280+
return eigen!(tcopy; kwargs...)
281+
end
282+
function eig(t::AbstractTensorMap; kwargs...)
283+
tcopy = copy!(similar(t, float(scalartype(t))), t)
284+
return eig!(tcopy; kwargs...)
285+
end
286+
function eigh(t::AbstractTensorMap; kwargs...)
287+
tcopy = copy!(similar(t, float(scalartype(t))), t)
288+
return eigh!(tcopy; kwargs...)
289+
end
290+
function LinearAlgebra.isposdef(t::AbstractTensorMap)
291+
tcopy = copy!(similar(t, float(scalartype(t))), t)
292+
return isposdef!(tcopy)
265293
end
266-
LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...) = eigen!(copy(t); kwargs...)
267-
eig(t::AbstractTensorMap; kwargs...) = eig!(copy(t); kwargs...)
268-
eigh(t::AbstractTensorMap; kwargs...) = eigh!(copy(t); kwargs...)
269-
LinearAlgebra.isposdef(t::AbstractTensorMap) = isposdef!(copy(t))
270294

271295
# Orthogonal factorizations (mutation for recycling memory):
296+
# only possible if scalar type is floating point
272297
# only correct if Euclidean inner product
273298
#------------------------------------------------------------------------------------------
274-
function leftorth!(t::TensorMap;
299+
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
300+
301+
function leftorth!(t::TensorMap{<:RealOrComplexFloat};
275302
alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(),
276303
atol::Real=zero(float(real(scalartype(t)))),
277304
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -321,7 +348,7 @@ function leftorth!(t::TensorMap;
321348
return Q, R
322349
end
323350

324-
function leftnull!(t::TensorMap;
351+
function leftnull!(t::TensorMap{<:RealOrComplexFloat};
325352
alg::Union{QR,QRpos,SVD,SDD}=QRpos(),
326353
atol::Real=zero(float(real(scalartype(t)))),
327354
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -360,7 +387,7 @@ function leftnull!(t::TensorMap;
360387
return N
361388
end
362389

363-
function rightorth!(t::TensorMap;
390+
function rightorth!(t::TensorMap{<:RealOrComplexFloat};
364391
alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar}=LQpos(),
365392
atol::Real=zero(float(real(scalartype(t)))),
366393
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -410,7 +437,7 @@ function rightorth!(t::TensorMap;
410437
return L, Q
411438
end
412439

413-
function rightnull!(t::TensorMap;
440+
function rightnull!(t::TensorMap{<:RealOrComplexFloat};
414441
alg::Union{LQ,LQpos,SVD,SDD}=LQpos(),
415442
atol::Real=zero(float(real(scalartype(t)))),
416443
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -476,7 +503,13 @@ end
476503
#------------------------------#
477504
# Singular value decomposition #
478505
#------------------------------#
479-
function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
506+
function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat})
507+
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
508+
end
509+
LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t))
510+
511+
function tsvd!(t::TensorMap{<:RealOrComplexFloat};
512+
trunc=NoTruncation(), p::Real=2, alg=SDD())
480513
return _tsvd!(t, alg, trunc, p)
481514
end
482515
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
@@ -485,7 +518,8 @@ function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
485518
end
486519

487520
# implementation dispatches on algorithm
488-
function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2)
521+
function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD},
522+
trunc::TruncationScheme, p::Real=2)
489523
# early return
490524
if isempty(blocksectors(t))
491525
truncerr = zero(real(scalartype(t)))
@@ -518,13 +552,17 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
518552
return SVDdata, dims
519553
end
520554

521-
function _create_svdtensors(t, SVDdata, dims)
555+
function _create_svdtensors(t::TensorMap{<:RealOrComplexFloat}, SVDdata, dims)
556+
T = scalartype(t)
522557
S = spacetype(t)
523558
W = S(dims)
524-
T = float(scalartype(t))
525-
U = similar(t, T, codomain(t) W)
526-
Σ = similar(t, real(T), W W)
527-
V⁺ = similar(t, T, W domain(t))
559+
560+
Tr = real(T)
561+
A = similarstoragetype(t, Tr)
562+
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)
563+
564+
U = similar(t, codomain(t) W)
565+
V⁺ = similar(t, W domain(t))
528566
for (c, (Uc, Σc, V⁺c)) in SVDdata
529567
r = Base.OneTo(dims[c])
530568
copy!(block(U, c), view(Uc, :, r))
@@ -534,38 +572,53 @@ function _create_svdtensors(t, SVDdata, dims)
534572
return U, Σ, V⁺
535573
end
536574

537-
function _empty_svdtensors(t)
575+
function _empty_svdtensors(t::TensorMap{<:RealOrComplexFloat})
576+
T = scalartype(t)
577+
S = spacetype(t)
538578
I = sectortype(t)
539579
dims = SectorDict{I,Int}()
540-
S = spacetype(t)
541580
W = S(dims)
581+
582+
Tr = real(T)
583+
A = similarstoragetype(t, Tr)
584+
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)
585+
542586
U = similar(t, codomain(t) W)
543-
Σ = similar(t, real(scalartype(t)), W W)
544587
V⁺ = similar(t, W domain(t))
545588
return U, Σ, V⁺
546589
end
547590

548591
#--------------------------#
549592
# Eigenvalue decomposition #
550593
#--------------------------#
551-
LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)
594+
function LinearAlgebra.eigen!(t::TensorMap{<:RealOrComplexFloat})
595+
return ishermitian(t) ? eigh!(t) : eig!(t)
596+
end
597+
598+
function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
599+
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
600+
for (c, b) in blocks(t))
601+
end
602+
function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...)
603+
return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...)))
604+
for (c, b) in blocks(t))
605+
end
552606

553-
function eigh!(t::TensorMap)
607+
function eigh!(t::TensorMap{<:RealOrComplexFloat})
554608
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
555609
domain(t) == codomain(t) ||
556610
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))
557611

612+
T = scalartype(t)
558613
I = sectortype(t)
614+
S = spacetype(t)
559615
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
560-
if length(domain(t)) == 1
561-
W = domain(t)[1]
562-
else
563-
S = spacetype(t)
564-
W = S(dims)
565-
end
566-
T = float(scalartype(t))
567-
V = similar(t, T, domain(t) W)
568-
D = similar(t, real(T), W W)
616+
W = S(dims)
617+
618+
Tr = real(T)
619+
A = similarstoragetype(t, Tr)
620+
D = DiagonalTensorMap{Tr,S,A}(undef, W)
621+
V = similar(t, domain(t) W)
569622
for (c, b) in blocks(t)
570623
values, vectors = MatrixAlgebra.eigh!(b)
571624
copy!(block(D, c), Diagonal(values))
@@ -574,20 +627,20 @@ function eigh!(t::TensorMap)
574627
return D, V
575628
end
576629

577-
function eig!(t::TensorMap; kwargs...)
630+
function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
578631
domain(t) == codomain(t) ||
579632
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))
633+
634+
T = scalartype(t)
580635
I = sectortype(t)
636+
S = spacetype(t)
581637
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
582-
if length(domain(t)) == 1
583-
W = domain(t)[1]
584-
else
585-
S = spacetype(t)
586-
W = S(dims)
587-
end
588-
T = complex(float(scalartype(t)))
589-
V = similar(t, T, domain(t) W)
590-
D = similar(t, T, W W)
638+
W = S(dims)
639+
640+
Tc = complex(T)
641+
A = similarstoragetype(t, Tc)
642+
D = DiagonalTensorMap{Tc,S,A}(undef, W)
643+
V = similar(t, Tc, domain(t) W)
591644
for (c, b) in blocks(t)
592645
values, vectors = MatrixAlgebra.eig!(b; kwargs...)
593646
copy!(block(D, c), Diagonal(values))

0 commit comments

Comments
 (0)