Skip to content

Commit f7c6f17

Browse files
InterdisciplinaryPhysicsTeampitmonticoneClaudMor
committed
Update tensorsfactorizations.jl
Co-Authored-By: Pietro Monticone <[email protected]> Co-Authored-By: Claudio Moroni <[email protected]>
1 parent 7675103 commit f7c6f17

File tree

1 file changed

+1
-126
lines changed

1 file changed

+1
-126
lines changed

src/tensorsfactorizations.jl

Lines changed: 1 addition & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,6 @@
1-
#= # This script has been copy-pasted from https://github.com/mhauru/TensorFactorizations.jl
1+
# This script has been copy-pasted from https://github.com/mhauru/TensorFactorizations.jl
22

3-
"""
4-
tensorsplit(A, a, b; kwargs...)
5-
6-
Calls tensorsvd with the arguments given to it to decompose the given tensor
7-
A with indices a on one side and indices b on the other. It then splits
8-
the diagonal matrix of singular values into two with a square root and
9-
multiplies these weights into the isometric tensors. Thus tensorsplit ends
10-
up splitting A into two parts, which are then returned, possibly together
11-
with auxiliary data such as a truncation error. If the keyword argument
12-
hermitian=true, an eigenvalue decomposition is used in stead of an SVD. All
13-
the keyword arguments are passed to either tensorsvd or tensoreig.
14-
15-
See tensorsvd and tensoreig for further documentation.
16-
"""
17-
function tensorsplit(args...; kwargs...)
18-
# Find the keyword argument hermitian.
19-
# TODO This is awful, why do I have to do this?
20-
hermitian = false
21-
for (key, value) in kwargs
22-
key == :hermitian && (hermitian = value)
23-
end
243

25-
if hermitian
26-
res = tensoreig(args...; kwargs...)
27-
S, U = res[1:2]
28-
Vt_perm = [ndims(U), (1:(ndims(U) - 1))...]
29-
Vt = conj!(tensorcopy(U, collect(1:ndims(U)), Vt_perm))
30-
S = Diagonal(S)
31-
if !isposdef(S)
32-
S = complex.(S)
33-
end
34-
auxdata = res[3:end]
35-
else
36-
res = tensorsvd(args...; kwargs...)
37-
U, S, Vt = res[1:3]
38-
S = Diagonal(S)
39-
auxdata = res[4:end]
40-
end
41-
S_sqrt = sqrt.(S)
42-
A1 = tensorcontract(U, (1:(ndims(U) - 1)..., :a), S_sqrt, (:a, :b))
43-
A2 = tensorcontract(S_sqrt, (:b, :a), Vt, (:a, 1:(ndims(Vt) - 1)...))
44-
return A1, A2, auxdata...
45-
end =#
464

475
"""
486
tensoreig(A, a, b; chis=nothing, eps=0,
@@ -120,90 +78,7 @@ function tensoreig(
12078
return retval
12179
end
12280

123-
#= """
124-
tensorsvd(A, a, b;
125-
chis=nothing, eps=0,
126-
return_error=false, print_error=false,
127-
break_degenerate=false, degeneracy_eps=1e-6,
128-
norm_type=:frobenius)
129-
130-
Singular valued decomposes a tensor A. The indices of A are
131-
permuted so that the indices listed in the Array/Tuple a are on the "left"
132-
side and indices listed in b are on the "right". The resulting tensor is
133-
then reshaped to a matrix, and this matrix is SVDed into U*diagm(S)*Vt.
134-
Finally, the unitary matrices U and Vt are reshaped to tensors so that
135-
they have a new index coming from the SVD, for U as the last index and for
136-
Vt as the first, and U has indices a as its first indices and V has
137-
indices b as its last indices.
138-
139-
If eps>0 then the SVD may be truncated if the relative error can be kept
140-
below eps. For this purpose different dimensions to truncate to can be tried,
141-
and these dimensions should be listed in chis. If chis is nothing (the
142-
default) then the full range of possible dimensions is tried. If
143-
break_degenerate=false (the default) then the truncation never cuts between
144-
degenerate singular values. degeneracy_eps controls how close the values need
145-
to be to be considered degenerate.
146-
147-
norm_type specifies the norm used to measure the error. This defaults to
148-
:frobenius, which means that the error measured is the Frobenius norm of the
149-
difference between A and the decomposition, divided by the Frobenius norm of
150-
A. This is the same thing as the 2-norm of the singular values that are
151-
truncated out, divided by the 2-norm of all the singular values. The other
152-
option is :trace, in which case a 1-norm is used instead.
153-
154-
If print_error=true the truncation error is printed. The default is false.
155-
156-
If return_error=true then the truncation error is also returned.
157-
The default is false.
15881

159-
Note that no iterative techniques are used, which means choosing to truncate
160-
provides no performance benefits: The full SVD is computed in any case.
161-
162-
Output is U, S, Vt, and possibly error. Here S is a vector of
163-
singular values and U and Vt are isometric tensors (unitary if the matrix
164-
that is SVDed is square and there is no truncation) such that U*diag(S)*Vt =
165-
A, up to truncation errors.
166-
"""
167-
function tensorsvd(
168-
A,
169-
a,
170-
b;
171-
chis=nothing,
172-
eps=0,
173-
return_error=false,
174-
print_error=false,
175-
break_degenerate=false,
176-
degeneracy_eps=1e-6,
177-
norm_type=:frobenius,
178-
)
179-
# Create the matrix and SVD it.
180-
A, shp_a, shp_b = to_matrix(A, a, b; return_tensor_shape=true)
181-
fact = svd(A)
182-
U, S, Vt = fact.U, fact.S, fact.Vt
183-
184-
# Find the dimensions to truncate to and the error caused in doing so.
185-
chi, error = find_trunc_dim(S, chis, eps, break_degenerate, degeneracy_eps, norm_type)
186-
# Truncate
187-
S = S[1:chi]
188-
U = U[:, 1:chi]
189-
Vt = Vt[1:chi, :]
190-
191-
if print_error
192-
println("Relative truncation error ($norm_type norm) in SVD: $error")
193-
end
194-
195-
# Reshape U and V to tensors with shapes matching the shape of A and
196-
# return.
197-
dim = size(S)[1]
198-
U_tens = reshape(U, shp_a..., dim)
199-
Vt_tens = reshape(Vt, dim, shp_b...)
200-
retval = (U_tens, S, Vt_tens)
201-
if return_error
202-
retval = (retval..., error)
203-
end
204-
return retval
205-
end
206-
=#
20782
"""
20883
Format the bond dimensions listed in chis to a standard format.
20984
"""

0 commit comments

Comments
 (0)