Skip to content

Commit b25d9f1

Browse files
authored
Fix gauge fixing in :fixed mode for non-uniform unit cells from full SVD (#249)
- Fix unit cell indices of extended sign spaces in `_fix_svd_algorithm` - Add gauge fixing test for non-uniform unit cells - Force `FixedSpaceTruncation` in `:fixed` `GradMode` before gauge fixing
1 parent 1a27f11 commit b25d9f1

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

src/algorithms/optimization/fixed_point_differentiation.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ function _rrule(
248248
alg::SimultaneousCTMRG,
249249
)
250250
env, = leading_boundary(envinit, state, alg)
251-
env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg)
251+
alg_fixed = @set alg.projector_alg.trscheme = FixedSpaceTruncation() # fix spaces during differentiation
252+
env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg_fixed)
252253
env_fixed, signs = gauge_fix(env, env_conv)
253254

254255
# Fix SVD
@@ -279,8 +280,20 @@ end
279280

280281
function _fix_svd_algorithm(alg::SVDAdjoint, signs, info)
281282
# embed gauge signs in larger space to fix gauge of full U and V on truncated subspace
282-
signs_full = map(zip(signs, info.S_full)) do (σ, S_full)
283-
extended_σ = zeros(scalartype(σ), space(S_full))
283+
rowsize, colsize = size(signs, 2), size(signs, 3)
284+
signs_full = map(Iterators.product(1:4, 1:rowsize, 1:colsize)) do (dir, r, c)
285+
σ = signs[dir, r, c]
286+
r_sign, c_sign = if dir == NORTH # take unit cell interdependency of signs into account
287+
r, _prev(c, colsize)
288+
elseif dir == EAST
289+
_prev(r, rowsize), c
290+
elseif dir == SOUTH
291+
r, _next(c, colsize)
292+
elseif dir == WEST
293+
_next(r, rowsize), c
294+
end
295+
extended_space = domain(info.U_full[dir, r_sign, c_sign]) codomain(info.V_full[dir, r_sign, c_sign])
296+
extended_σ = zeros(scalartype(σ), extended_space)
284297
for (c, b) in blocks(extended_σ)
285298
σc = block(σ, c)
286299
kept_dim = size(σc, 1)

test/ctmrg/unitcell.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Test
22
using Random
33
using PEPSKit
4-
using PEPSKit: _prev, _next, ctmrg_iteration
4+
using PEPSKit: _prev, _next, ctmrg_iteration, gauge_fix, _fix_svd_algorithm
55
using TensorKit
66

77
# settings
@@ -22,7 +22,8 @@ function test_unitcell(
2222
env = CTMRGEnv(randn, stype, peps, chis_north, chis_east, chis_south, chis_west)
2323

2424
# apply one CTMRG iteration with fixeds
25-
env′, = ctmrg_iteration(InfiniteSquareNetwork(peps), env, ctm_alg)
25+
env′, info = ctmrg_iteration(InfiniteSquareNetwork(peps), env, ctm_alg)
26+
env″, info = ctmrg_iteration(InfiniteSquareNetwork(peps), env′, ctm_alg) # another iteration to fix spaces
2627

2728
# compute random expecation value to test matching bonds
2829
random_op = LocalOperator(
@@ -35,7 +36,17 @@ function test_unitcell(
3536
]...,
3637
)
3738
@test expectation_value(peps, random_op, env) isa Number
38-
return @test expectation_value(peps, random_op, env′) isa Number
39+
@test expectation_value(peps, random_op, env′) isa Number
40+
41+
# test if gauge fixing routines run through
42+
_, signs = gauge_fix(env′, env″)
43+
@test signs isa Array
44+
return if ctm_alg isa SimultaneousCTMRG # also test :fixed mode gauge fixing for simultaneous CTMRG
45+
svd_alg_fixed_full = _fix_svd_algorithm(SVDAdjoint(; fwd_alg = (; alg = :sdd)), signs, info)
46+
svd_alg_fixed_iter = _fix_svd_algorithm(SVDAdjoint(; fwd_alg = (; alg = :iterative)), signs, info)
47+
@test svd_alg_fixed_full isa SVDAdjoint
48+
@test svd_alg_fixed_iter isa SVDAdjoint
49+
end
3950
end
4051

4152
function random_dualize!(M::AbstractMatrix{<:ElementarySpace})

0 commit comments

Comments
 (0)