Skip to content

Commit 1b9f2c4

Browse files
authored
Add files via upload
Fixed up the docstring only appearing on sweep_contract! and not sweep_contract.
1 parent 0431eb2 commit 1b9f2c4

File tree

1 file changed

+226
-0
lines changed

1 file changed

+226
-0
lines changed

sweep_contract.jl

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
const MPSTensor = Array{Float64,3}
2+
const MPS = Vector{MPSTensor}
3+
4+
# Split a single multi-site tensor into single-site MPS tensors in the dumb way
5+
function splitMPStensor(T::Array{Float64})
6+
v = MPS(undef,ndims(T)-2);
7+
(l,r) = (2,ndims(T)-1)
8+
s = collect(size(T))
9+
while r>l
10+
# Calculate the bond dimensions of sweeping tensor either way and minimise
11+
L = s[l-1]*s[l]
12+
R = s[r]*s[r+1]
13+
if L <= R
14+
v[l-1] = reshape(Matrix{Float64}(LinearAlgebra.I,L,L),(s[l-1],s[l],L));
15+
s[l] *= s[l-1]
16+
l += 1
17+
else
18+
v[r-1] = reshape(Matrix{Float64}(LinearAlgebra.I,R,R),(R,s[r],s[r+1]));
19+
s[r] *= s[r+1]
20+
r -= 1
21+
end
22+
end
23+
v[l-1] = reshape(T,(s[l-1],s[l],s[l+1]))
24+
return v
25+
end
26+
27+
# Truncates down the bond dimension of an MPS, performing a (lossy) compression
28+
function truncMPS!(M::MPS, χ::Int64)
29+
# Put the MPS in canonical form using the QR decomposition, sweeping left-to-right
30+
for i 1:length(M)-1
31+
X = reshape(M[i],(size(M[i],1)*size(M[i],2),size(M[i],3)))
32+
q,r = LinearAlgebra.qr(X)
33+
if size(r,1)==size(r,2)
34+
M[i] = reshape(Matrix(q),size(M[i]))
35+
LinearAlgebra.lmul!(LinearAlgebra.UpperTriangular(r),
36+
reshape(M[i+1],(size(M[i+1],1),size(M[i+1],2)*size(M[i+1],3))))
37+
else
38+
M[i] = reshape(Matrix(q),(size(M[i],1),size(M[i],2),size(r,1)))
39+
M[i+1] = reshape(r*reshape(M[i+1], (size(M[i+1],1),
40+
size(M[i+1],2)*size(M[i+1],3))), (size(r,1),size(M[i+1],2),size(M[i+1],3)))
41+
end
42+
end
43+
# Perform the bond truncation using the SVD decomposition, sweeping right-to-left
44+
for i length(M):-1:2
45+
X = reshape(M[i],(size(M[i],1),size(M[i],2)*size(M[i],3)))
46+
# In some rare cases the default svd can fail to converge
47+
try
48+
F = LinearAlgebra.svd!(X);
49+
(u,s,v) = (F.U,F.S,F.V)
50+
catch _
51+
F = LinearAlgebra.svd!(X; alg=LinearAlgebra.QRIteration())
52+
(u,s,v) = (F.U,F.S,F.V)
53+
end
54+
b = min(length(s),χ)
55+
u = u[:,1:b]
56+
s = s[1:b]
57+
v = v[:,1:b]'
58+
M[i] = reshape(v,(b,size(M[i],2),size(M[i],3)));
59+
X = reshape(M[i-1],(size(M[i-1],1)*size(M[i-1],2),size(M[i-1],3)))*u
60+
LinearAlgebra.rmul!(X,LinearAlgebra.Diagonal(s));
61+
M[i-1] = reshape(X,(size(M[i-1],1),size(M[i-1],2),b));
62+
end
63+
return M
64+
end
65+
66+
# Find the permutation transformation between two vectors
67+
function permutebetween(from, to)
68+
σf = sortperm(from)
69+
σt = sortperm(to)
70+
arr = Vector{Int}(undef,length(from))
71+
for i eachindex(from)
72+
arr[σt[i]]=σf[i]
73+
end
74+
return arr
75+
end
76+
77+
"""
78+
sweep_contract(LTN::LabelledTensorNetwork, χ, τ;
79+
fast=false, valid=false, planar=false, connected=false, report=false)
80+
sweep_contract(TN::TensorNetwork, χ, τ;
81+
fast=false, valid=false, planar=false, connected=false, report=false)
82+
83+
Returns the contraction of the `TensorNetwork TN`, or the `LabelledTensorNetwork LTN` using
84+
the sweepline contraction algorithm of `arXiv:2101.04125`. The MPS is truncated down to a
85+
bond dimension of `χ` whenever any bond dimension exceeds `τ`.
86+
87+
By default the network is checked for validity, planarised, and sweep-connected, where
88+
necessary. The keyword flags `valid`, `planar`, and `connected` can be used to skip these,
89+
or the flag `fast` can be used to skip them all. If these flags are enabled then contraction
90+
may fail on poorly formed networks.
91+
92+
To avoid underflow/overflow issues the contraction value of the network is returned as a
93+
tuple `(f::Float64, i::Int64)` where `1≦f<2` or `f` is `0`, representing a value of `f*2^i`.
94+
The function `ldexp` can be used to convert this back to a Float64.
95+
96+
`sweep_contract` is non-mutating and acts upon a deep copy of the network, where possible
97+
use the more efficient mutating version `sweep_contract!`.
98+
"""
99+
sweep_contract(LTN::LabelledTensorNetwork, χ::Int, τ::Int;
100+
fast=false, valid=false, planar=false, connected=false, report=false) =
101+
sweep_contract!(deepcopy(LTN), χ, τ;
102+
fast=fast, planar=planar, connected=connected, report=report)
103+
104+
sweep_contract(TN::TensorNetwork, χ::Int, τ::Int;
105+
fast=false, valid=false, planar=false, connected=false, report=false) =
106+
sweep_contract!(deepcopy(TN), χ, τ;
107+
fast=fast, planar=planar, connected=connected, report=report)
108+
109+
"""
110+
sweep_contract!(LTN::LabelledTensorNetwork, χ, τ;
111+
fast=false, valid=false, planar=false, connected=false, report=false)
112+
sweep_contract!(TN::TensorNetwork, χ, τ;
113+
fast=false, valid=false, planar=false, connected=false, report=false)
114+
115+
The mutating form of `sweep_contract`.
116+
"""
117+
sweep_contract!(LTN::LabelledTensorNetwork, χ::Int, τ::Int;
118+
fast=false, valid=false, planar=false, connected=false, report=false) =
119+
sweep_contract!(delabel(LTN), χ, τ;
120+
fast=fast, planar=planar, connected=connected, report=report)
121+
122+
function sweep_contract!(TN::TensorNetwork, χ::Int, τ::Int;
123+
fast=false,valid=false,planar=false,connected=false,report=false)::Tuple{Float64,Int}
124+
if !fast
125+
valid || checkvalid(TN)
126+
planar || planarise!(TN)
127+
connected || connect!(hull!(TN))
128+
end
129+
130+
sort!(TN)
131+
132+
N = length(TN)
133+
134+
resexp = 0
135+
count = 0
136+
137+
MPS_t = [ones(1,1,1)]
138+
MPS_i = Int[]
139+
140+
for (i,t) enumerate(TN)
141+
ind_up = Int[]
142+
ind_do = Int[]
143+
for n t.adj
144+
if TN[n]>t
145+
push!(ind_up, n)
146+
elseif TN[n]<t
147+
push!(ind_do, n)
148+
else
149+
throw(InvalidTNError("Overlapping tensors"))
150+
end
151+
end
152+
sort!(ind_up, by=λ->atan(TN[λ].x-t.x,TN[λ].y-t.y))
153+
sort!(ind_do, by=λ->atan(TN[λ].x-t.x,t.y-TN[λ].y))
154+
σ = permutebetween(t.adj, [ind_do; ind_up])
155+
t.arr = permutedims(t.arr, σ)
156+
s = size(t.arr)
157+
t.arr = reshape(t.arr,(prod(s[1:length(ind_do)]),s[length(ind_do)+1:end]...))
158+
159+
if isempty(MPS_i)
160+
MPS_t = splitMPStensor(MPS_t[1][1]*reshape(t.arr,(size(t.arr)...,1)))
161+
MPS_i = ind_up
162+
else
163+
lo = findfirst(isequal(i), MPS_i)
164+
hi = findlast(isequal(i), MPS_i)
165+
166+
isnothing(lo) && throw(InvalidTNError("Disconnected TN"))
167+
168+
X::Array{Float64} = MPS_t[lo]
169+
for j lo+1:hi
170+
finalsize = (size(X,1),size(X,2)*size(MPS_t[j],2),size(MPS_t[j],3))
171+
X = reshape(X,(size(X,1)*size(X,2),size(X,3)))*
172+
reshape(MPS_t[j],(size(MPS_t[j],1),size(MPS_t[j],2)*size(MPS_t[j],3)))
173+
X = reshape(X,finalsize)
174+
end
175+
X = permutedims(X,[1,3,2])
176+
M = reshape(t.arr,(size(t.arr,1),prod(size(t.arr)[2:end])))
177+
X = reshape(
178+
reshape(X,(size(X,1)*size(X,2),size(X,3)))*M,
179+
(size(X,1),size(X,2),size(M,2))
180+
)
181+
X = permutedims(X,[1,3,2])
182+
X = reshape(X,(size(X,1),size(t.arr)[2:end]...,size(X,3)))
183+
184+
MPS_i = [MPS_i[1:lo-1]; ind_up; MPS_i[hi+1:end]]
185+
if ndims(X)!=2
186+
MPS_t = [MPS_t[1:lo-1]; splitMPStensor(X); MPS_t[hi+1:end]]
187+
elseif isempty(MPS_i)
188+
MPS_t=[reshape([X[1]],(1,1,1))]
189+
elseif lo>1
190+
s = size(MPS_t[lo-1])
191+
MPS_t[lo-1] = reshape(
192+
reshape(MPS_t[lo-1],(s[1]*s[2],s[3]))*X,
193+
(s[1],s[2],size(X,2))
194+
)
195+
MPS_t = [MPS_t[1:lo-1]; MPS_t[hi+1:end]]
196+
else
197+
s = size(MPS_t[hi+1])
198+
MPS_t[hi+1] = reshape(
199+
X*reshape(MPS_t[hi+1],(s[1],s[2]*s[3])),
200+
(size(X,1),s[2],s[3])
201+
)
202+
MPS_t = [MPS_t[1:lo-1]; MPS_t[hi+1:end]]
203+
end
204+
205+
if any(size.(MPS_t,3).>τ)
206+
count += 1
207+
truncMPS!(MPS_t, χ)
208+
if LinearAlgebra.norm(MPS_t[1])==0
209+
return (0.0,typemin(Int))
210+
end
211+
h = Int(floor(log2(LinearAlgebra.norm(MPS_t[1]))))
212+
resexp += h
213+
MPS_t[1] /= exp2(h)
214+
end
215+
end
216+
end
217+
218+
report && println("Number of truncations: $count")
219+
220+
res = MPS_t[1][1];
221+
if res == 0.0
222+
return (0.0, typemin(Int64));
223+
end
224+
h = Int(floor(log2(abs(res))));
225+
return (res/exp2(h),resexp+h);
226+
end

0 commit comments

Comments
 (0)