Skip to content

Commit 7865d5e

Browse files
lkdvosYue-Zhengyuankshyatt
authored
Belief propagation gauge fixing (#223)
* Add `BPEnv` * Add `BPEnv` constructors * Add BP contractions * Add BP VectorInterface support * Add BP nearest neighbor expectation values * Add bp normalization * Move contractions to separate file * Add BP fixed point * Add `gauge_fix` function * Remove repeated `import VectorInterface` * Format * Kill ProductSpaceLike * Try to add some simple tests * Add BP to the test workflow too and format * Death to VectorInterface * Some 0.7 updates * Format * Absorb sqrt(weights) into vertex tensors [skip ci] * Add docstring for BPEnv [skip ci] * export BPEnv, BeliefPropagation * Fix bond mismatch and enforce Hermiticity in `bp_iteration` * Add test to compare BP and SU fixed point * Fix BP gauge fixing * Fix formatting * Preserve virtual arrows with flip_svd * Implement rotation of BPEnv * Add BPEnv to CTMRGEnv conversion [skip ci] * Separate bp_fixedpoint and gauge_fix * Rename test file * Add docstring for BPEnv constructors * Add Base.size * Add BP expectation value test * Fix unitcell test * add some type properties, simplify CTMRGEnv converter * make hermiticity an algorithm parameter * improve docstrings, add miniter * normalize BP error criterion * small fixes * refactor gauge fixing * improve type stability * initialize PEPS messages with identity * return gauges * small fixes * Adapt to removal of flip_svd * Separate BPGauge from BP * Bring back `SUWeight(::BPEnv)` * Move `random_dual!` to utils; add fermion gauge fix test * Refactor BPEnv constructor * Move trivial SU gauging to src * Add `BPEnv(::SUWeight)` * Change west, south message axis order * Temp. remove fermion gauging test; add non-hermitian test * Automatically check Hermiticity of messages when fixing gauge * Add `posdef` option for BPEnv constructors * Handle posdef-ness of fermionic messages * Fix BP for fermions with non-standard virtual arrows * Update docstrings and error messages * Update docstrings again --------- Co-authored-by: Yue Zhengyuan <[email protected]> Co-authored-by: Katharine Hyatt <[email protected]>
1 parent 10559ba commit 7865d5e

File tree

15 files changed

+1087
-7
lines changed

15 files changed

+1087
-7
lines changed

.github/workflows/Tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
- 'lts' # minimal supported version
2727
- '1'
2828
group:
29+
- bp
2930
- types
3031
- ctmrg
3132
- boundarymps

src/PEPSKit.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@ include("operators/models.jl")
5151
include("environments/ctmrg_environments.jl")
5252
include("environments/vumps_environments.jl")
5353
include("environments/suweight.jl")
54+
include("environments/bp_environments.jl")
5455

5556
include("algorithms/contractions/ctmrg_contractions.jl")
5657
include("algorithms/contractions/transfer.jl")
5758
include("algorithms/contractions/localoperator.jl")
5859
include("algorithms/contractions/vumps_contractions.jl")
60+
include("algorithms/contractions/bp_contractions.jl")
5961
include("algorithms/contractions/bondenv/benv_tools.jl")
6062
include("algorithms/contractions/bondenv/gaugefix.jl")
6163
include("algorithms/contractions/bondenv/als_solve.jl")
@@ -78,6 +80,10 @@ include("algorithms/time_evolution/evoltools.jl")
7880
include("algorithms/time_evolution/time_evolve.jl")
7981
include("algorithms/time_evolution/simpleupdate.jl")
8082
include("algorithms/time_evolution/simpleupdate3site.jl")
83+
include("algorithms/time_evolution/gaugefix_su.jl")
84+
85+
include("algorithms/bp/beliefpropagation.jl")
86+
include("algorithms/bp/gaugefix.jl")
8187

8288
include("algorithms/transfermatrix.jl")
8389
include("algorithms/toolbox.jl")
@@ -114,6 +120,9 @@ export InfinitePartitionFunction
114120
export InfinitePEPS, InfiniteTransferPEPS
115121
export SUWeight
116122
export InfinitePEPO, InfiniteTransferPEPO
123+
124+
export BPEnv, BeliefPropagation, BPGauge
125+
117126
export initialize_mps, initializePEPS
118127
export ReflectDepth, ReflectWidth, Rotate, RotateReflect
119128
export symmetrize!, symmetrize_retract_and_finalize!
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
struct BeliefPropagation
3+
4+
Algorithm for computing the belief propagation fixed point messages.
5+
6+
## Fields
7+
8+
$(TYPEDFIELDS)
9+
"""
10+
@kwdef struct BeliefPropagation
11+
"Stopping criterion for the BP iterations in relative trace norm difference"
12+
tol::Float64 = 1.0e-6
13+
14+
"Minimal number of BP iterations"
15+
miniter::Int = 2
16+
17+
"Maximal number of BP iterations"
18+
maxiter::Int = 50
19+
20+
"Toggle for projecting messages onto the hermitian subspace immediately after update through BP equation"
21+
project_hermitian::Bool = true
22+
23+
"Output verbosity level"
24+
verbosity::Int = 2
25+
end
26+
27+
"""
28+
leading_boundary(env₀::BPEnv, network, alg::BeliefPropagation)
29+
30+
Contract `network` in the BP approximation and return the corresponding messages.
31+
"""
32+
function leading_boundary(env₀::BPEnv, network::InfiniteSquareNetwork, alg::BeliefPropagation)
33+
return LoggingExtras.withlevel(; alg.verbosity) do
34+
env = deepcopy(env₀)
35+
log = MPSKit.IterLog("BP")
36+
ϵ = Inf
37+
@infov 1 loginit!(log, ϵ)
38+
for iter in 1:(alg.maxiter)
39+
env′ = bp_iteration(network, env, alg)
40+
ϵ = oftype(ϵ, tr_distance(env, env′))
41+
env = env′
42+
43+
if ϵ <= alg.tol && iter >= alg.miniter
44+
@infov 2 logfinish!(log, iter, ϵ)
45+
break
46+
end
47+
if iter alg.maxiter
48+
@warnv 1 logcancel!(log, iter, ϵ)
49+
else
50+
@infov 3 logiter!(log, iter, ϵ)
51+
end
52+
end
53+
54+
return env, ϵ
55+
end
56+
end
57+
function leading_boundary(env₀::BPEnv, state, args...)
58+
return leading_boundary(env₀, InfiniteSquareNetwork(state), args...)
59+
end
60+
61+
"""
62+
One iteration to update the BP environment.
63+
"""
64+
function bp_iteration(network::InfiniteSquareNetwork, env::BPEnv, alg::BeliefPropagation)
65+
messages = map(eachindex(env)) do I
66+
M = update_message(I, network, env)
67+
normalize!(M)
68+
alg.project_hermitian && (M = project_hermitian!!(M))
69+
return M
70+
end
71+
return BPEnv(messages)
72+
end
73+
74+
"""
75+
Update the BP message in `env.messages[I]`.
76+
"""
77+
function update_message(I::CartesianIndex{3}, network::InfiniteSquareNetwork, env::BPEnv)
78+
dir, row, col = Tuple(I)
79+
(1 <= dir <= 4) || throw(ArgumentError("Invalid direction $dir"))
80+
81+
A = network[row, col]
82+
dir == SOUTH || (M_north = env[NORTH, _prev(row, end), col])
83+
dir == WEST || (M_east = env[EAST, row, _next(col, end)])
84+
dir == NORTH || (M_south = env[SOUTH, _next(row, end), col])
85+
dir == EAST || (M_west = env[WEST, row, _prev(col, end)])
86+
87+
return if dir == NORTH
88+
contract_north_message(A, M_west, M_north, M_east)
89+
elseif dir == EAST
90+
contract_east_message(A, M_north, M_east, M_south)
91+
elseif dir == SOUTH
92+
contract_south_message(A, M_east, M_south, M_west)
93+
else # dir == WEST
94+
contract_west_message(A, M_south, M_west, M_north)
95+
end
96+
end
97+
98+
function tr_distance(A::BPEnv, B::BPEnv)
99+
return sum(zip(A.messages, B.messages)) do (a, b)
100+
return trnorm(add(a, b, -inv(tr(b)), inv(tr(a))))
101+
end / length(A.messages)
102+
end
103+
104+
function trnorm(M::AbstractTensorMap, p::Real = 1)
105+
return TensorKit._norm(svdvals(M), p, zero(real(scalartype(M))))
106+
end
107+
function trnorm!(M::AbstractTensorMap, p::Real = 1)
108+
return TensorKit._norm(svdvals!(M), p, zero(real(scalartype(M))))
109+
end
110+
111+
project_hermitian!!(t) = add(t, t', 1 / 2, 1 / 2)

src/algorithms/bp/gaugefix.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
struct BPGauge
3+
4+
Algorithm for gauging PEPS with belief propagation fixed point messages.
5+
"""
6+
@kwdef struct BPGauge
7+
# TODO: add options
8+
end
9+
10+
"""
11+
$(SIGNATURES)
12+
13+
Fix the gauge of `psi` using fixed point environment `env` of belief propagation.
14+
"""
15+
function gauge_fix(psi::InfinitePEPS, alg::BPGauge, env::BPEnv)
16+
psi′ = copy(psi)
17+
XXinv = map(eachcoordinate(psi, 1:2)) do I
18+
_, X, Xinv = _bp_gauge_fix!(CartesianIndex(I), psi′, env)
19+
return X, Xinv
20+
end
21+
return psi′, XXinv
22+
end
23+
24+
function _sqrt_bp_messages(I::CartesianIndex{3}, env::BPEnv)
25+
dir, row, col = Tuple(I)
26+
@assert dir == NORTH || dir == EAST
27+
M12 = env[dir, dir == NORTH ? _prev(row, end) : row, dir == EAST ? _next(col, end) : col]
28+
sqrtM12, isqrtM12 = sqrt_invsqrt(twist(M12, 1))
29+
M21 = env[dir + 2, row, col]
30+
sqrtM21, isqrtM21 = sqrt_invsqrt(M21)
31+
return sqrtM12, isqrtM12, sqrtM21, isqrtM21
32+
end
33+
34+
"""
35+
_bp_gauge_fix!(I, psi::InfinitePEPS, env::BPEnv) -> psi, X, X⁻¹
36+
37+
For the bond at direction `I[1]` (which can be `NORTH` or `EAST`)
38+
from site `I[2], I[3]`, we identify the following gauge matrices,
39+
along the canonical direction of the PEPS arrows (`SOUTH ← NORTH` or `WEST ← EAST`):
40+
41+
```math
42+
I = √M₁₂⁻¹ √M₁₂ √M₂₁ √M₂₁⁻¹
43+
= √M₁₂⁻¹ (U Λ Vᴴ) √M₂₁⁻¹
44+
= (√M₁₂⁻¹ U √Λ) (√Λ Vᴴ √M₂₁⁻¹)
45+
= X X⁻¹
46+
```
47+
48+
Which are then used to update the gauge of `psi`. Thus, by convention `X` is attached to the `SOUTH`/`WEST` directions
49+
and `X⁻¹` is attached to the `NORTH`/`EAST` directions.
50+
"""
51+
function _bp_gauge_fix!(I::CartesianIndex{3}, psi::InfinitePEPS, env::BPEnv)
52+
dir, row, col = Tuple(I)
53+
@assert dir == NORTH || dir == EAST
54+
55+
sqrtM12, isqrtM12, sqrtM21, isqrtM21 = _sqrt_bp_messages(I, env)
56+
U, Λ, Vᴴ = svd_compact!(sqrtM12 * sqrtM21)
57+
sqrtΛ = sdiag_pow(Λ, 1 / 2)
58+
X = isqrtM12 * U * sqrtΛ
59+
invX = sqrtΛ * Vᴴ * isqrtM21
60+
if isdual(space(sqrtM12, 1))
61+
X, invX = twist(flip(X, 2), 1), flip(invX, 1)
62+
end
63+
if dir == NORTH
64+
psi[row, col] = absorb_north_message(psi[row, col], X)
65+
psi[_prev(row, end), col] = absorb_south_message(psi[_prev(row, end), col], invX)
66+
elseif dir == EAST
67+
psi[row, col] = absorb_east_message(psi[row, col], X)
68+
psi[row, _next(col, end)] = absorb_west_message(psi[row, _next(col, end)], invX)
69+
end
70+
return psi, X, invX
71+
end
72+
73+
"""
74+
SUWeight(env::BPEnv)
75+
76+
Construct `SUWeight` from belief propagation fixed point environment `env`.
77+
"""
78+
function SUWeight(env::BPEnv)
79+
wts = map(Iterators.product(1:2, axes(env, 2), axes(env, 3))) do (dir′, row, col)
80+
I = CartesianIndex(mod1(dir′ + 1, 2), row, col)
81+
sqrtM12, _, sqrtM21, _ = _sqrt_bp_messages(I, env)
82+
Λ = svd_vals!(sqrtM12 * sqrtM21)
83+
return isdual(space(sqrtM12, 1)) ? _fliptwist_s(Λ) : Λ
84+
end
85+
return SUWeight(wts)
86+
end
87+
88+
"""
89+
BPEnv(wts::SUWeight)
90+
91+
Convert fixed point weights `wts` of trivial simple update
92+
to a belief propagation environment.
93+
"""
94+
function BPEnv(wts::SUWeight)
95+
messages = map(Iterators.product(1:4, axes(wts, 2), axes(wts, 3))) do (d, r, c)
96+
wt = if d == NORTH
97+
twist(wts[2, _next(r, end), c], 1)
98+
elseif d == EAST
99+
twist(wts[1, r, _prev(c, end)], 1)
100+
elseif d == SOUTH
101+
copy(wts[2, r, c])
102+
else # WEST
103+
copy(wts[1, r, c])
104+
end
105+
return TensorMap(wt)
106+
end
107+
return BPEnv(messages)
108+
end
109+
110+
function sqrt_invsqrt(A::PEPSMessage)
111+
if isposdef(A)
112+
D, V = eigh_full(A)
113+
sqrtA = V * sdiag_pow(D, 1 / 2) * V'
114+
isqrtA = V * sdiag_pow(D, -1 / 2) * V'
115+
else
116+
D, V = eig_full(A)
117+
V⁻¹ = inv(V)
118+
sqrtA = V * sdiag_pow(D, 1 / 2) * V⁻¹
119+
isqrtA = V * sdiag_pow(D, -1 / 2) * V⁻¹
120+
if scalartype(A) <: Real
121+
# TODO: is this valid?
122+
sqrtA = real(sqrtA)
123+
isqrtA = real(isqrtA)
124+
end
125+
end
126+
return sqrtA, isqrtA
127+
end

0 commit comments

Comments
 (0)