Skip to content

Commit a8aa774

Browse files
authored
use diagonaltensormap (#190)
* use diagonaltensormap * streamline type and copy
1 parent ef42572 commit a8aa774

File tree

1 file changed

+136
-70
lines changed

1 file changed

+136
-70
lines changed

src/tensors/factorizations.jl

Lines changed: 136 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Tensor factorization
22
#----------------------
3+
function factorisation_scalartype(t::AbstractTensorMap)
4+
T = scalartype(t)
5+
return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
6+
end
7+
factorisation_scalartype(f, t) = factorisation_scalartype(t)
8+
9+
function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple)
10+
return permute!(similar(t, T, permute(space(t), p)), t, p)
11+
end
12+
function copy_oftype(t::AbstractTensorMap, T::Type{<:Number})
13+
return copy!(similar(t, T), t)
14+
end
15+
316
"""
417
tsvd(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
518
trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD())
@@ -36,13 +49,14 @@ algorithm that computes the decomposition (`_gesvd` or `_gesdd`).
3649
Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)`
3750
is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`.
3851
"""
39-
function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
40-
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
52+
function tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
53+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p)
54+
return tsvd!(tcopy; kwargs...)
4155
end
4256

43-
LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
44-
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
45-
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
57+
function LinearAlgebra.svdvals(t::AbstractTensorMap)
58+
tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t))
59+
return LinearAlgebra.svdvals!(tcopy)
4660
end
4761

4862
"""
@@ -67,8 +81,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
6781
`leftorth(!)` is currently only implemented for
6882
`InnerProductStyle(t) === EuclideanInnerProduct()`.
6983
"""
70-
function leftorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
71-
return leftorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
84+
function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
85+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p)
86+
return leftorth!(tcopy; kwargs...)
7287
end
7388

7489
"""
@@ -95,8 +110,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
95110
`rightorth(!)` is currently only implemented for
96111
`InnerProductStyle(t) === EuclideanInnerProduct()`.
97112
"""
98-
function rightorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
99-
return rightorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
113+
function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
114+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p)
115+
return rightorth!(tcopy; kwargs...)
100116
end
101117

102118
"""
@@ -121,8 +137,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
121137
`leftnull(!)` is currently only implemented for
122138
`InnerProductStyle(t) === EuclideanInnerProduct()`.
123139
"""
124-
function leftnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
125-
return leftnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
140+
function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
141+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p)
142+
return leftnull!(tcopy; kwargs...)
126143
end
127144

128145
"""
@@ -149,8 +166,9 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
149166
`rightnull(!)` is currently only implemented for
150167
`InnerProductStyle(t) === EuclideanInnerProduct()`.
151168
"""
152-
function rightnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
153-
return rightnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
169+
function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
170+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p)
171+
return rightnull!(tcopy; kwargs...)
154172
end
155173

156174
"""
@@ -172,17 +190,14 @@ matrices. See the corresponding documentation for more information.
172190
173191
See also `eig` and `eigh`
174192
"""
175-
function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
176-
kwargs...)
177-
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
193+
function LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
194+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p)
195+
return eigen!(tcopy; kwargs...)
178196
end
179197

180198
function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
181-
return LinearAlgebra.eigvals!(copy(t); kwargs...)
182-
end
183-
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
184-
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
185-
for (c, b) in blocks(t))
199+
tcopy = copy_oftype(t, factorisation_scalartype(eigen, t))
200+
return LinearAlgebra.eigvals!(tcopy; kwargs...)
186201
end
187202

188203
"""
@@ -207,8 +222,9 @@ matrices. See the corresponding documentation for more information.
207222
208223
See also `eigen` and `eigh`.
209224
"""
210-
function eig(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
211-
return eig!(permute(t, (p₁, p₂); copy=true); kwargs...)
225+
function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
226+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p)
227+
return eig!(tcopy; kwargs...)
212228
end
213229

214230
"""
@@ -231,8 +247,9 @@ permute(t, (leftind, rightind)) * V = V * D
231247
232248
See also `eigen` and `eig`.
233249
"""
234-
function eigh(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
235-
return eigh!(permute(t, (p₁, p₂); copy=true))
250+
function eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
251+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigh, t), p)
252+
return eigh!(tcopy; kwargs...)
236253
end
237254

238255
"""
@@ -247,31 +264,54 @@ which `isposdef!` is called should have equal domain and codomain, as otherwise
247264
meaningless.
248265
"""
249266
function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
250-
return isposdef!(permute(t, (p₁, p₂); copy=true))
267+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p)
268+
return isposdef!(tcopy)
251269
end
252270

253-
tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...)
271+
function tsvd(t::AbstractTensorMap; kwargs...)
272+
tcopy = copy!(similar(t, float(scalartype(t))), t)
273+
return tsvd!(tcopy; kwargs...)
274+
end
254275
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
255-
return leftorth!(copy(t); alg=alg, kwargs...)
276+
tcopy = copy!(similar(t, float(scalartype(t))), t)
277+
return leftorth!(tcopy; alg=alg, kwargs...)
256278
end
257279
function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...)
258-
return rightorth!(copy(t); alg=alg, kwargs...)
280+
tcopy = copy!(similar(t, float(scalartype(t))), t)
281+
return rightorth!(tcopy; alg=alg, kwargs...)
259282
end
260283
function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...)
261-
return leftnull!(copy(t); alg=alg, kwargs...)
284+
tcopy = copy!(similar(t, float(scalartype(t))), t)
285+
return leftnull!(tcopy; alg=alg, kwargs...)
262286
end
263287
function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...)
264-
return rightnull!(copy(t); alg=alg, kwargs...)
288+
tcopy = copy!(similar(t, float(scalartype(t))), t)
289+
return rightnull!(tcopy; alg=alg, kwargs...)
290+
end
291+
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
292+
tcopy = copy!(similar(t, float(scalartype(t))), t)
293+
return eigen!(tcopy; kwargs...)
294+
end
295+
function eig(t::AbstractTensorMap; kwargs...)
296+
tcopy = copy!(similar(t, float(scalartype(t))), t)
297+
return eig!(tcopy; kwargs...)
298+
end
299+
function eigh(t::AbstractTensorMap; kwargs...)
300+
tcopy = copy!(similar(t, float(scalartype(t))), t)
301+
return eigh!(tcopy; kwargs...)
302+
end
303+
function LinearAlgebra.isposdef(t::AbstractTensorMap)
304+
tcopy = copy!(similar(t, float(scalartype(t))), t)
305+
return isposdef!(tcopy)
265306
end
266-
LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...) = eigen!(copy(t); kwargs...)
267-
eig(t::AbstractTensorMap; kwargs...) = eig!(copy(t); kwargs...)
268-
eigh(t::AbstractTensorMap; kwargs...) = eigh!(copy(t); kwargs...)
269-
LinearAlgebra.isposdef(t::AbstractTensorMap) = isposdef!(copy(t))
270307

271308
# Orthogonal factorizations (mutation for recycling memory):
309+
# only possible if scalar type is floating point
272310
# only correct if Euclidean inner product
273311
#------------------------------------------------------------------------------------------
274-
function leftorth!(t::TensorMap;
312+
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
313+
314+
function leftorth!(t::TensorMap{<:RealOrComplexFloat};
275315
alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(),
276316
atol::Real=zero(float(real(scalartype(t)))),
277317
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -321,7 +361,7 @@ function leftorth!(t::TensorMap;
321361
return Q, R
322362
end
323363

324-
function leftnull!(t::TensorMap;
364+
function leftnull!(t::TensorMap{<:RealOrComplexFloat};
325365
alg::Union{QR,QRpos,SVD,SDD}=QRpos(),
326366
atol::Real=zero(float(real(scalartype(t)))),
327367
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -360,7 +400,7 @@ function leftnull!(t::TensorMap;
360400
return N
361401
end
362402

363-
function rightorth!(t::TensorMap;
403+
function rightorth!(t::TensorMap{<:RealOrComplexFloat};
364404
alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar}=LQpos(),
365405
atol::Real=zero(float(real(scalartype(t)))),
366406
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -410,7 +450,7 @@ function rightorth!(t::TensorMap;
410450
return L, Q
411451
end
412452

413-
function rightnull!(t::TensorMap;
453+
function rightnull!(t::TensorMap{<:RealOrComplexFloat};
414454
alg::Union{LQ,LQpos,SVD,SDD}=LQpos(),
415455
atol::Real=zero(float(real(scalartype(t)))),
416456
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
@@ -476,7 +516,13 @@ end
476516
#------------------------------#
477517
# Singular value decomposition #
478518
#------------------------------#
479-
function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
519+
function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat})
520+
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
521+
end
522+
LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t))
523+
524+
function tsvd!(t::TensorMap{<:RealOrComplexFloat};
525+
trunc=NoTruncation(), p::Real=2, alg=SDD())
480526
return _tsvd!(t, alg, trunc, p)
481527
end
482528
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
@@ -485,7 +531,8 @@ function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
485531
end
486532

487533
# implementation dispatches on algorithm
488-
function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2)
534+
function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD},
535+
trunc::TruncationScheme, p::Real=2)
489536
# early return
490537
if isempty(blocksectors(t))
491538
truncerr = zero(real(scalartype(t)))
@@ -518,13 +565,17 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
518565
return SVDdata, dims
519566
end
520567

521-
function _create_svdtensors(t, SVDdata, dims)
568+
function _create_svdtensors(t::TensorMap{<:RealOrComplexFloat}, SVDdata, dims)
569+
T = scalartype(t)
522570
S = spacetype(t)
523571
W = S(dims)
524-
T = float(scalartype(t))
525-
U = similar(t, T, codomain(t) W)
526-
Σ = similar(t, real(T), W W)
527-
V⁺ = similar(t, T, W domain(t))
572+
573+
Tr = real(T)
574+
A = similarstoragetype(t, Tr)
575+
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)
576+
577+
U = similar(t, codomain(t) W)
578+
V⁺ = similar(t, W domain(t))
528579
for (c, (Uc, Σc, V⁺c)) in SVDdata
529580
r = Base.OneTo(dims[c])
530581
copy!(block(U, c), view(Uc, :, r))
@@ -534,38 +585,53 @@ function _create_svdtensors(t, SVDdata, dims)
534585
return U, Σ, V⁺
535586
end
536587

537-
function _empty_svdtensors(t)
588+
function _empty_svdtensors(t::TensorMap{<:RealOrComplexFloat})
589+
T = scalartype(t)
590+
S = spacetype(t)
538591
I = sectortype(t)
539592
dims = SectorDict{I,Int}()
540-
S = spacetype(t)
541593
W = S(dims)
594+
595+
Tr = real(T)
596+
A = similarstoragetype(t, Tr)
597+
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)
598+
542599
U = similar(t, codomain(t) W)
543-
Σ = similar(t, real(scalartype(t)), W W)
544600
V⁺ = similar(t, W domain(t))
545601
return U, Σ, V⁺
546602
end
547603

548604
#--------------------------#
549605
# Eigenvalue decomposition #
550606
#--------------------------#
551-
LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)
607+
function LinearAlgebra.eigen!(t::TensorMap{<:RealOrComplexFloat})
608+
return ishermitian(t) ? eigh!(t) : eig!(t)
609+
end
610+
611+
function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
612+
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
613+
for (c, b) in blocks(t))
614+
end
615+
function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...)
616+
return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...)))
617+
for (c, b) in blocks(t))
618+
end
552619

553-
function eigh!(t::TensorMap)
620+
function eigh!(t::TensorMap{<:RealOrComplexFloat})
554621
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
555622
domain(t) == codomain(t) ||
556623
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))
557624

625+
T = scalartype(t)
558626
I = sectortype(t)
627+
S = spacetype(t)
559628
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
560-
if length(domain(t)) == 1
561-
W = domain(t)[1]
562-
else
563-
S = spacetype(t)
564-
W = S(dims)
565-
end
566-
T = float(scalartype(t))
567-
V = similar(t, T, domain(t) W)
568-
D = similar(t, real(T), W W)
629+
W = S(dims)
630+
631+
Tr = real(T)
632+
A = similarstoragetype(t, Tr)
633+
D = DiagonalTensorMap{Tr,S,A}(undef, W)
634+
V = similar(t, domain(t) W)
569635
for (c, b) in blocks(t)
570636
values, vectors = MatrixAlgebra.eigh!(b)
571637
copy!(block(D, c), Diagonal(values))
@@ -574,20 +640,20 @@ function eigh!(t::TensorMap)
574640
return D, V
575641
end
576642

577-
function eig!(t::TensorMap; kwargs...)
643+
function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
578644
domain(t) == codomain(t) ||
579645
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))
646+
647+
T = scalartype(t)
580648
I = sectortype(t)
649+
S = spacetype(t)
581650
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
582-
if length(domain(t)) == 1
583-
W = domain(t)[1]
584-
else
585-
S = spacetype(t)
586-
W = S(dims)
587-
end
588-
T = complex(float(scalartype(t)))
589-
V = similar(t, T, domain(t) W)
590-
D = similar(t, T, W W)
651+
W = S(dims)
652+
653+
Tc = complex(T)
654+
A = similarstoragetype(t, Tc)
655+
D = DiagonalTensorMap{Tc,S,A}(undef, W)
656+
V = similar(t, Tc, domain(t) W)
591657
for (c, b) in blocks(t)
592658
values, vectors = MatrixAlgebra.eig!(b; kwargs...)
593659
copy!(block(D, c), Diagonal(values))

0 commit comments

Comments
 (0)