Skip to content
19 changes: 16 additions & 3 deletions src/algorithms/optimization/fixed_point_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ function _rrule(
alg::SimultaneousCTMRG,
)
env, = leading_boundary(envinit, state, alg)
env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg)
alg_fixed = @set alg.projector_alg.trscheme = FixedSpaceTruncation() # fix spaces during differentiation
env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg_fixed)
env_fixed, signs = gauge_fix(env, env_conv)

# Fix SVD
Expand Down Expand Up @@ -279,8 +280,20 @@ end

function _fix_svd_algorithm(alg::SVDAdjoint, signs, info)
# embed gauge signs in larger space to fix gauge of full U and V on truncated subspace
signs_full = map(zip(signs, info.S_full)) do (σ, S_full)
extended_σ = zeros(scalartype(σ), space(S_full))
rowsize, colsize = size(signs, 2), size(signs, 3)
signs_full = map(Iterators.product(1:4, 1:rowsize, 1:colsize)) do (dir, r, c)
σ = signs[dir, r, c]
r_sign, c_sign = if dir == NORTH # take unit cell interdependency of signs into account
r, _prev(c, colsize)
elseif dir == EAST
_prev(r, rowsize), c
elseif dir == SOUTH
r, _next(c, colsize)
elseif dir == WEST
_next(r, rowsize), c
end
extended_space = domain(info.U_full[dir, r_sign, c_sign]) codomain(info.V_full[dir, r_sign, c_sign])
extended_σ = zeros(scalartype(σ), extended_space)
for (c, b) in blocks(extended_σ)
σc = block(σ, c)
kept_dim = size(σc, 1)
Expand Down
17 changes: 14 additions & 3 deletions test/ctmrg/unitcell.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Test
using Random
using PEPSKit
using PEPSKit: _prev, _next, ctmrg_iteration
using PEPSKit: _prev, _next, ctmrg_iteration, gauge_fix, _fix_svd_algorithm
using TensorKit

# settings
Expand All @@ -22,7 +22,8 @@ function test_unitcell(
env = CTMRGEnv(randn, stype, peps, chis_north, chis_east, chis_south, chis_west)

# apply one CTMRG iteration with fixeds
env′, = ctmrg_iteration(InfiniteSquareNetwork(peps), env, ctm_alg)
env′, info = ctmrg_iteration(InfiniteSquareNetwork(peps), env, ctm_alg)
env″, info = ctmrg_iteration(InfiniteSquareNetwork(peps), env′, ctm_alg) # another iteration to fix spaces

# compute random expecation value to test matching bonds
random_op = LocalOperator(
Expand All @@ -35,7 +36,17 @@ function test_unitcell(
]...,
)
@test expectation_value(peps, random_op, env) isa Number
return @test expectation_value(peps, random_op, env′) isa Number
@test expectation_value(peps, random_op, env′) isa Number

# test if gauge fixing routines run through
_, signs = gauge_fix(env′, env″)
@test signs isa Array
return if ctm_alg isa SimultaneousCTMRG # also test :fixed mode gauge fixing for simultaneous CTMRG
svd_alg_fixed_full = _fix_svd_algorithm(SVDAdjoint(; fwd_alg = (; alg = :sdd)), signs, info)
svd_alg_fixed_iter = _fix_svd_algorithm(SVDAdjoint(; fwd_alg = (; alg = :iterative)), signs, info)
@test svd_alg_fixed_full isa SVDAdjoint
@test svd_alg_fixed_iter isa SVDAdjoint
end
end

function random_dualize!(M::AbstractMatrix{<:ElementarySpace})
Expand Down
Loading