|
| 1 | +const MPSTensor = Array{Float64,3} |
| 2 | +const MPS = Vector{MPSTensor} |
| 3 | + |
| 4 | +# Split a single multi-site tensor into single-site MPS tensors in the dumb way |
| 5 | +function splitMPStensor(T::Array{Float64}) |
| 6 | + v = MPS(undef,ndims(T)-2); |
| 7 | + (l,r) = (2,ndims(T)-1) |
| 8 | + s = collect(size(T)) |
| 9 | + while r>l |
| 10 | + # Calculate the bond dimensions of sweeping tensor either way and minimise |
| 11 | + L = s[l-1]*s[l] |
| 12 | + R = s[r]*s[r+1] |
| 13 | + if L <= R |
| 14 | + v[l-1] = reshape(Matrix{Float64}(LinearAlgebra.I,L,L),(s[l-1],s[l],L)); |
| 15 | + s[l] *= s[l-1] |
| 16 | + l += 1 |
| 17 | + else |
| 18 | + v[r-1] = reshape(Matrix{Float64}(LinearAlgebra.I,R,R),(R,s[r],s[r+1])); |
| 19 | + s[r] *= s[r+1] |
| 20 | + r -= 1 |
| 21 | + end |
| 22 | + end |
| 23 | + v[l-1] = reshape(T,(s[l-1],s[l],s[l+1])) |
| 24 | + return v |
| 25 | +end |
| 26 | + |
| 27 | +# Truncates down the bond dimension of an MPS, performing a (lossy) compression |
| 28 | +function truncMPS!(M::MPS, χ::Int64) |
| 29 | + # Put the MPS in canonical form using the QR decomposition, sweeping left-to-right |
| 30 | + for i ∈ 1:length(M)-1 |
| 31 | + X = reshape(M[i],(size(M[i],1)*size(M[i],2),size(M[i],3))) |
| 32 | + q,r = LinearAlgebra.qr(X) |
| 33 | + if size(r,1)==size(r,2) |
| 34 | + M[i] = reshape(Matrix(q),size(M[i])) |
| 35 | + LinearAlgebra.lmul!(LinearAlgebra.UpperTriangular(r), |
| 36 | + reshape(M[i+1],(size(M[i+1],1),size(M[i+1],2)*size(M[i+1],3)))) |
| 37 | + else |
| 38 | + M[i] = reshape(Matrix(q),(size(M[i],1),size(M[i],2),size(r,1))) |
| 39 | + M[i+1] = reshape(r*reshape(M[i+1], (size(M[i+1],1), |
| 40 | + size(M[i+1],2)*size(M[i+1],3))), (size(r,1),size(M[i+1],2),size(M[i+1],3))) |
| 41 | + end |
| 42 | + end |
| 43 | + # Perform the bond truncation using the SVD decomposition, sweeping right-to-left |
| 44 | + for i ∈ length(M):-1:2 |
| 45 | + X = reshape(M[i],(size(M[i],1),size(M[i],2)*size(M[i],3))) |
| 46 | + # In some rare cases the default svd can fail to converge |
| 47 | + try |
| 48 | + F = LinearAlgebra.svd!(X); |
| 49 | + (u,s,v) = (F.U,F.S,F.V) |
| 50 | + catch _ |
| 51 | + F = LinearAlgebra.svd!(X; alg=LinearAlgebra.QRIteration()) |
| 52 | + (u,s,v) = (F.U,F.S,F.V) |
| 53 | + end |
| 54 | + b = min(length(s),χ) |
| 55 | + u = u[:,1:b] |
| 56 | + s = s[1:b] |
| 57 | + v = v[:,1:b]' |
| 58 | + M[i] = reshape(v,(b,size(M[i],2),size(M[i],3))); |
| 59 | + X = reshape(M[i-1],(size(M[i-1],1)*size(M[i-1],2),size(M[i-1],3)))*u |
| 60 | + LinearAlgebra.rmul!(X,LinearAlgebra.Diagonal(s)); |
| 61 | + M[i-1] = reshape(X,(size(M[i-1],1),size(M[i-1],2),b)); |
| 62 | + end |
| 63 | + return M |
| 64 | +end |
| 65 | + |
| 66 | +# Find the permutation transformation between two vectors |
| 67 | +function permutebetween(from, to) |
| 68 | + σf = sortperm(from) |
| 69 | + σt = sortperm(to) |
| 70 | + arr = Vector{Int}(undef,length(from)) |
| 71 | + for i ∈ eachindex(from) |
| 72 | + arr[σt[i]]=σf[i] |
| 73 | + end |
| 74 | + return arr |
| 75 | +end |
| 76 | + |
| 77 | +""" |
| 78 | + sweep_contract(LTN::LabelledTensorNetwork, χ, τ; |
| 79 | + fast=false, valid=false, planar=false, connected=false, report=false) |
| 80 | + sweep_contract(TN::TensorNetwork, χ, τ; |
| 81 | + fast=false, valid=false, planar=false, connected=false, report=false) |
| 82 | +
|
| 83 | +Returns the contraction of the `TensorNetwork TN`, or the `LabelledTensorNetwork LTN` using |
| 84 | +the sweepline contraction algorithm of `arXiv:2101.04125`. The MPS is truncated down to a |
| 85 | +bond dimension of `χ` whenever any bond dimension exceeds `τ`. |
| 86 | +
|
| 87 | +By default the network is checked for validity, planarised, and sweep-connected, where |
| 88 | +necessary. The keyword flags `valid`, `planar`, and `connected` can be used to skip these, |
| 89 | +or the flag `fast` can be used to skip them all. If these flags are enabled then contraction |
| 90 | +may fail on poorly formed networks. |
| 91 | +
|
| 92 | +To avoid underflow/overflow issues the contraction value of the network is returned as a |
| 93 | +tuple `(f::Float64, i::Int64)` where `1≦f<2` or `f` is `0`, representing a value of `f*2^i`. |
| 94 | +The function `ldexp` can be used to convert this back to a Float64. |
| 95 | +
|
| 96 | +`sweep_contract` is non-mutating and acts upon a deep copy of the network, where possible |
| 97 | +use the more efficient mutating version `sweep_contract!`. |
| 98 | +""" |
| 99 | +sweep_contract(LTN::LabelledTensorNetwork, χ::Int, τ::Int; |
| 100 | + fast=false, valid=false, planar=false, connected=false, report=false) = |
| 101 | +sweep_contract!(deepcopy(LTN), χ, τ; |
| 102 | + fast=fast, planar=planar, connected=connected, report=report) |
| 103 | + |
| 104 | +sweep_contract(TN::TensorNetwork, χ::Int, τ::Int; |
| 105 | + fast=false, valid=false, planar=false, connected=false, report=false) = |
| 106 | +sweep_contract!(deepcopy(TN), χ, τ; |
| 107 | + fast=fast, planar=planar, connected=connected, report=report) |
| 108 | + |
| 109 | +""" |
| 110 | + sweep_contract!(LTN::LabelledTensorNetwork, χ, τ; |
| 111 | + fast=false, valid=false, planar=false, connected=false, report=false) |
| 112 | + sweep_contract!(TN::TensorNetwork, χ, τ; |
| 113 | + fast=false, valid=false, planar=false, connected=false, report=false) |
| 114 | +
|
| 115 | +The mutating form of `sweep_contract`. |
| 116 | +""" |
| 117 | +sweep_contract!(LTN::LabelledTensorNetwork, χ::Int, τ::Int; |
| 118 | + fast=false, valid=false, planar=false, connected=false, report=false) = |
| 119 | +sweep_contract!(delabel(LTN), χ, τ; |
| 120 | + fast=fast, planar=planar, connected=connected, report=report) |
| 121 | + |
| 122 | +function sweep_contract!(TN::TensorNetwork, χ::Int, τ::Int; |
| 123 | + fast=false,valid=false,planar=false,connected=false,report=false)::Tuple{Float64,Int} |
| 124 | + if !fast |
| 125 | + valid || checkvalid(TN) |
| 126 | + planar || planarise!(TN) |
| 127 | + connected || connect!(hull!(TN)) |
| 128 | + end |
| 129 | + |
| 130 | + sort!(TN) |
| 131 | + |
| 132 | + N = length(TN) |
| 133 | + |
| 134 | + resexp = 0 |
| 135 | + count = 0 |
| 136 | + |
| 137 | + MPS_t = [ones(1,1,1)] |
| 138 | + MPS_i = Int[] |
| 139 | + |
| 140 | + for (i,t) ∈ enumerate(TN) |
| 141 | + ind_up = Int[] |
| 142 | + ind_do = Int[] |
| 143 | + for n ∈ t.adj |
| 144 | + if TN[n]>t |
| 145 | + push!(ind_up, n) |
| 146 | + elseif TN[n]<t |
| 147 | + push!(ind_do, n) |
| 148 | + else |
| 149 | + throw(InvalidTNError("Overlapping tensors")) |
| 150 | + end |
| 151 | + end |
| 152 | + sort!(ind_up, by=λ->atan(TN[λ].x-t.x,TN[λ].y-t.y)) |
| 153 | + sort!(ind_do, by=λ->atan(TN[λ].x-t.x,t.y-TN[λ].y)) |
| 154 | + σ = permutebetween(t.adj, [ind_do; ind_up]) |
| 155 | + t.arr = permutedims(t.arr, σ) |
| 156 | + s = size(t.arr) |
| 157 | + t.arr = reshape(t.arr,(prod(s[1:length(ind_do)]),s[length(ind_do)+1:end]...)) |
| 158 | + |
| 159 | + if isempty(MPS_i) |
| 160 | + MPS_t = splitMPStensor(MPS_t[1][1]*reshape(t.arr,(size(t.arr)...,1))) |
| 161 | + MPS_i = ind_up |
| 162 | + else |
| 163 | + lo = findfirst(isequal(i), MPS_i) |
| 164 | + hi = findlast(isequal(i), MPS_i) |
| 165 | + |
| 166 | + isnothing(lo) && throw(InvalidTNError("Disconnected TN")) |
| 167 | + |
| 168 | + X::Array{Float64} = MPS_t[lo] |
| 169 | + for j ∈ lo+1:hi |
| 170 | + finalsize = (size(X,1),size(X,2)*size(MPS_t[j],2),size(MPS_t[j],3)) |
| 171 | + X = reshape(X,(size(X,1)*size(X,2),size(X,3)))* |
| 172 | + reshape(MPS_t[j],(size(MPS_t[j],1),size(MPS_t[j],2)*size(MPS_t[j],3))) |
| 173 | + X = reshape(X,finalsize) |
| 174 | + end |
| 175 | + X = permutedims(X,[1,3,2]) |
| 176 | + M = reshape(t.arr,(size(t.arr,1),prod(size(t.arr)[2:end]))) |
| 177 | + X = reshape( |
| 178 | + reshape(X,(size(X,1)*size(X,2),size(X,3)))*M, |
| 179 | + (size(X,1),size(X,2),size(M,2)) |
| 180 | + ) |
| 181 | + X = permutedims(X,[1,3,2]) |
| 182 | + X = reshape(X,(size(X,1),size(t.arr)[2:end]...,size(X,3))) |
| 183 | + |
| 184 | + MPS_i = [MPS_i[1:lo-1]; ind_up; MPS_i[hi+1:end]] |
| 185 | + if ndims(X)!=2 |
| 186 | + MPS_t = [MPS_t[1:lo-1]; splitMPStensor(X); MPS_t[hi+1:end]] |
| 187 | + elseif isempty(MPS_i) |
| 188 | + MPS_t=[reshape([X[1]],(1,1,1))] |
| 189 | + elseif lo>1 |
| 190 | + s = size(MPS_t[lo-1]) |
| 191 | + MPS_t[lo-1] = reshape( |
| 192 | + reshape(MPS_t[lo-1],(s[1]*s[2],s[3]))*X, |
| 193 | + (s[1],s[2],size(X,2)) |
| 194 | + ) |
| 195 | + MPS_t = [MPS_t[1:lo-1]; MPS_t[hi+1:end]] |
| 196 | + else |
| 197 | + s = size(MPS_t[hi+1]) |
| 198 | + MPS_t[hi+1] = reshape( |
| 199 | + X*reshape(MPS_t[hi+1],(s[1],s[2]*s[3])), |
| 200 | + (size(X,1),s[2],s[3]) |
| 201 | + ) |
| 202 | + MPS_t = [MPS_t[1:lo-1]; MPS_t[hi+1:end]] |
| 203 | + end |
| 204 | + |
| 205 | + if any(size.(MPS_t,3).>τ) |
| 206 | + count += 1 |
| 207 | + truncMPS!(MPS_t, χ) |
| 208 | + if LinearAlgebra.norm(MPS_t[1])==0 |
| 209 | + return (0.0,typemin(Int)) |
| 210 | + end |
| 211 | + h = Int(floor(log2(LinearAlgebra.norm(MPS_t[1])))) |
| 212 | + resexp += h |
| 213 | + MPS_t[1] /= exp2(h) |
| 214 | + end |
| 215 | + end |
| 216 | + end |
| 217 | + |
| 218 | + report && println("Number of truncations: $count") |
| 219 | + |
| 220 | + res = MPS_t[1][1]; |
| 221 | + if res == 0.0 |
| 222 | + return (0.0, typemin(Int64)); |
| 223 | + end |
| 224 | + h = Int(floor(log2(abs(res)))); |
| 225 | + return (res/exp2(h),resexp+h); |
| 226 | +end |
0 commit comments