Skip to content

Commit efe4ce0

Browse files
committed
streamline type and copy
1 parent 4e351d3 commit efe4ce0

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

src/tensors/factorizations.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Tensor factorization
22
#----------------------
3+
function factorisation_scalartype(t::AbstractTensorMap)
4+
T = scalartype(t)
5+
return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
6+
end
7+
factorisation_scalartype(f, t) = factorisation_scalartype(t)
8+
9+
function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple)
10+
return permute!(similar(t, T, permute(space(t), p)), t, p)
11+
end
12+
function copy_oftype(t::AbstractTensorMap, T::Type{<:Number})
13+
return copy!(similar(t, T), t)
14+
end
15+
316
"""
417
tsvd(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
518
trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD())
@@ -37,12 +50,12 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)`
3750
is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`.
3851
"""
3952
function tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
40-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
53+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p)
4154
return tsvd!(tcopy; kwargs...)
4255
end
4356

4457
function LinearAlgebra.svdvals(t::AbstractTensorMap)
45-
tcopy = copy!(similar(t, float(scalartype(t))), t)
58+
tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t))
4659
return LinearAlgebra.svdvals!(tcopy)
4760
end
4861

@@ -69,7 +82,7 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
6982
`InnerProductStyle(t) === EuclideanInnerProduct()`.
7083
"""
7184
function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
72-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
85+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p)
7386
return leftorth!(tcopy; kwargs...)
7487
end
7588

@@ -98,7 +111,7 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
98111
`InnerProductStyle(t) === EuclideanInnerProduct()`.
99112
"""
100113
function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
101-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
114+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p)
102115
return rightorth!(tcopy; kwargs...)
103116
end
104117

@@ -125,7 +138,7 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
125138
`InnerProductStyle(t) === EuclideanInnerProduct()`.
126139
"""
127140
function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
128-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
141+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p)
129142
return leftnull!(tcopy; kwargs...)
130143
end
131144

@@ -154,7 +167,7 @@ Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and
154167
`InnerProductStyle(t) === EuclideanInnerProduct()`.
155168
"""
156169
function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
157-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
170+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p)
158171
return rightnull!(tcopy; kwargs...)
159172
end
160173

@@ -178,12 +191,12 @@ matrices. See the corresponding documentation for more information.
178191
See also `eig` and `eigh`
179192
"""
180193
function LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
181-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
194+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p)
182195
return eigen!(tcopy; kwargs...)
183196
end
184197

185198
function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
186-
tcopy = copy!(similar(t, float(scalartype(t))), t)
199+
tcopy = copy_oftype(t, factorisation_scalartype(eigen, t))
187200
return LinearAlgebra.eigvals!(tcopy; kwargs...)
188201
end
189202

@@ -210,7 +223,7 @@ matrices. See the corresponding documentation for more information.
210223
See also `eigen` and `eigh`.
211224
"""
212225
function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
213-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
226+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p)
214227
return eig!(tcopy; kwargs...)
215228
end
216229

@@ -235,7 +248,7 @@ permute(t, (leftind, rightind)) * V = V * D
235248
See also `eigen` and `eig`.
236249
"""
237250
function eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
238-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
251+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigh, t), p)
239252
return eigh!(tcopy; kwargs...)
240253
end
241254

@@ -251,7 +264,7 @@ which `isposdef!` is called should have equal domain and codomain, as otherwise
251264
meaningless.
252265
"""
253266
function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
254-
tcopy = permute!(similar(t, float(scalartype(t)), permute(space(t), p)), t, p)
267+
tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p)
255268
return isposdef!(tcopy)
256269
end
257270

0 commit comments

Comments
 (0)