Skip to content

Commit 66f56f1

Browse files
sanderdemeyerYue-Zhengyuanlkdvos
authored
Add SiteDependentTruncation (#211)
* add anisotropic Simple Update * fix 3site su * make SiteDependentTruncation a new truncation scheme Simple update was changed accordingly, CTMRG not yet * update on arguments and constructors * Fix 3-site SU for SiteDependentTruncationScheme * small documentation update * Change name and add rotl90 * fix export * change arguments in test * add test for 3-site cluster update * change test to increase patch coverage * change name again * small updates * rewrite mirror_antidiag * add test on non-square unit cell and fix bug * make the new test a separate test to test bipartite=true * Restore old Heisenberg SU-AD test * Add test of SU with SiteDependentTruncation * Add TODO on bipartite check for SiteDependentTruncation * Update src/algorithms/time_evolution/simpleupdate3site.jl Co-authored-by: Yue Zhengyuan <yuezy1997@icloud.com> * remove redundant TensorKit. before TruncationScheme * remove redundant TensorKit. before TruncationScheme bis * remove redundant where clauses * Add selection support with symbol * Change back constant * Also simplify rotation * Remove test on Hubbard SU with SiteDependentTruncation --------- Co-authored-by: Yue Zhengyuan <yuezy1997@icloud.com> Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 14b3630 commit 66f56f1

File tree

9 files changed

+171
-29
lines changed

9 files changed

+171
-29
lines changed

src/PEPSKit.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ using Compat
55
using Accessors: @set, @reset
66
using VectorInterface
77
import VectorInterface as VI
8-
using TensorKit, KrylovKit, OptimKit, TensorOperations
8+
9+
using TensorKit
10+
using TensorKit: TruncationScheme
11+
12+
using KrylovKit, OptimKit, TensorOperations
913
using ChainRulesCore, Zygote
1014
using LoggingExtras
1115

@@ -83,7 +87,8 @@ using .Defaults: set_scheduler!
8387
export set_scheduler!
8488
export SVDAdjoint, FullSVDReverseRule, IterSVD
8589
export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG
86-
export FixedSpaceTruncation, HalfInfiniteProjector, FullInfiniteProjector
90+
export FixedSpaceTruncation, SiteDependentTruncation
91+
export HalfInfiniteProjector, FullInfiniteProjector
8792
export LocalOperator, physicalspace
8893
export expectation_value, cost_function, product_peps, correlation_length, network_value
8994
export correlator

src/algorithms/time_evolution/simpleupdate.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct SimpleUpdate
1212
dt::Number
1313
tol::Float64
1414
maxiter::Int
15-
trscheme::TensorKit.TruncationScheme
15+
trscheme::TruncationScheme
1616
end
1717
# TODO: add kwarg constructor and SU Defaults
1818

@@ -34,7 +34,7 @@ function _su_xbond!(
3434
col::Int,
3535
gate::AbstractTensorMap{T,S,2,2},
3636
peps::InfiniteWeightPEPS,
37-
alg::SimpleUpdate,
37+
trscheme::TruncationScheme,
3838
) where {T<:Number,S<:ElementarySpace}
3939
Nr, Nc = size(peps)
4040
@assert 1 <= row <= Nr && 1 <= col <= Nc
@@ -47,7 +47,7 @@ function _su_xbond!(
4747
B = _absorb_weights(B, peps.weights, row, cp1, Tuple(1:4), sqrtsB, false)
4848
# apply gate
4949
X, a, b, Y = _qr_bond(A, B)
50-
a, s, b, ϵ = _apply_gate(a, b, gate, alg.trscheme)
50+
a, s, b, ϵ = _apply_gate(a, b, gate, trscheme)
5151
A, B = _qr_bond_undo(X, a, b, Y)
5252
# remove environment weights
5353
_allfalse = ntuple(Returns(false), 3)
@@ -86,6 +86,9 @@ function su_iter(
8686
# to update them using code for x-weights
8787
if direction == 2
8888
peps2 = mirror_antidiag(peps2)
89+
trscheme = mirror_antidiag(alg.trscheme)
90+
else
91+
trscheme = alg.trscheme
8992
end
9093
if bipartite
9194
for r in 1:2
@@ -94,7 +97,7 @@ function su_iter(
9497
direction == 1 ? gate : gate_mirrored,
9598
(CartesianIndex(r, 1), CartesianIndex(r, 2)),
9699
)
97-
ϵ = _su_xbond!(r, 1, term, peps2, alg)
100+
ϵ = _su_xbond!(r, 1, term, peps2, truncation_scheme(trscheme, 1, r, 1))
98101
peps2.vertices[rp1, 2] = deepcopy(peps2.vertices[r, 1])
99102
peps2.vertices[rp1, 1] = deepcopy(peps2.vertices[r, 2])
100103
peps2.weights[1, rp1, 2] = deepcopy(peps2.weights[1, r, 1])
@@ -106,7 +109,7 @@ function su_iter(
106109
direction == 1 ? gate : gate_mirrored,
107110
(CartesianIndex(r, c), CartesianIndex(r, c + 1)),
108111
)
109-
ϵ = _su_xbond!(r, c, term, peps2, alg)
112+
ϵ = _su_xbond!(r, c, term, peps2, truncation_scheme(trscheme, 1, r, c))
110113
end
111114
end
112115
if direction == 2
@@ -185,6 +188,7 @@ function simpleupdate(
185188
nnonly = is_nearest_neighbour(ham)
186189
use_3site = force_3site || !nnonly
187190
@assert !(bipartite && use_3site) "3-site simple update is incompatible with bipartite lattice."
191+
# TODO: check SiteDependentTruncation is compatible with bipartite structure
188192
if use_3site
189193
return _simpleupdate3site(peps, ham, alg; check_interval)
190194
else

src/algorithms/time_evolution/simpleupdate3site.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ The arrows between `Pa`, `s`, `Pb` are
215215
function _proj_from_RL(
216216
r::AbstractTensorMap{T,S,1,1},
217217
l::AbstractTensorMap{T,S,1,1};
218-
trunc::TensorKit.TruncationScheme=notrunc(),
218+
trunc::TruncationScheme=notrunc(),
219219
rev::Bool=false,
220220
) where {T<:Number,S<:ElementarySpace}
221221
rl = r * l
@@ -235,16 +235,19 @@ end
235235
Given a cluster `Ms` and the pre-calculated `R`, `L` bond matrices,
236236
find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds.
237237
"""
238-
function _get_allprojs(Ms, Rs, Ls, trunc::TensorKit.TruncationScheme, revs::Vector{Bool})
238+
function _get_allprojs(
239+
Ms, Rs, Ls, trschemes::Vector{E}, revs::Vector{Bool}
240+
) where {E<:TruncationScheme}
239241
N = length(Ms)
242+
@assert length(trschemes) == N - 1
240243
projs_errs = map(1:(N - 1)) do i
241-
trunc2 = if isa(trunc, FixedSpaceTruncation)
244+
trunc = if isa(trschemes[i], FixedSpaceTruncation)
242245
V = space(Ms[i + 1], 1)
243246
truncspace(isdual(V) ? V' : V)
244247
else
245-
trunc
248+
trschemes[i]
246249
end
247-
return _proj_from_RL(Rs[i], Ls[i]; trunc=trunc2, rev=revs[i])
250+
return _proj_from_RL(Rs[i], Ls[i]; trunc, rev=revs[i])
248251
end
249252
Pas = map(Base.Fix2(getindex, 1), projs_errs)
250253
wts = map(Base.Fix2(getindex, 2), projs_errs)
@@ -258,10 +261,10 @@ end
258261
Find projectors to truncate internal bonds of the cluster `Ms`
259262
"""
260263
function _cluster_truncate!(
261-
Ms::Vector{T}, trunc::TensorKit.TruncationScheme, revs::Vector{Bool}
262-
) where {T<:PEPSTensor}
264+
Ms::Vector{T}, trschemes::Vector{E}, revs::Vector{Bool}
265+
) where {T<:PEPSTensor,E<:TruncationScheme}
263266
Rs, Ls = _get_allRLs(Ms)
264-
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, trunc, revs)
267+
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, trschemes, revs)
265268
# apply projectors
266269
# M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3
267270
for (i, (Pa, Pb)) in enumerate(zip(Pas, Pbs))
@@ -322,13 +325,13 @@ In the cluster, the axes of each PEPSTensor are reordered as
322325
```
323326
"""
324327
function apply_gatempo!(
325-
Ms::Vector{T1}, gs::Vector{T2}; trunc::TensorKit.TruncationScheme
326-
) where {T1<:PEPSTensor,T2<:AbstractTensorMap}
328+
Ms::Vector{T1}, gs::Vector{T2}; trschemes::Vector{E}
329+
) where {T1<:PEPSTensor,T2<:AbstractTensorMap,E<:TruncationScheme}
327330
@assert length(Ms) == length(gs)
328331
revs = [isdual(space(M, 1)) for M in Ms[2:end]]
329332
@assert !all(revs)
330333
_apply_gatempo!(Ms, gs)
331-
wts, ϵs, = _cluster_truncate!(Ms, trunc, revs)
334+
wts, ϵs, = _cluster_truncate!(Ms, trschemes, revs)
332335
return wts, ϵs
333336
end
334337
@@ -373,8 +376,8 @@ function get_3site_se(peps::InfiniteWeightPEPS, row::Int, col::Int)
373376
end
374377
375378
function _su3site_se!(
376-
row::Int, col::Int, gs::Vector{T}, peps::InfiniteWeightPEPS, alg::SimpleUpdate
377-
) where {T<:AbstractTensorMap}
379+
row::Int, col::Int, gs::Vector{T}, peps::InfiniteWeightPEPS, trschemes::Vector{E}
380+
) where {T<:AbstractTensorMap,E<:TruncationScheme}
378381
Nr, Nc = size(peps)
379382
@assert 1 <= row <= Nr && 1 <= col <= Nc
380383
rm1, cp1 = _prev(row, Nr), _next(col, Nc)
@@ -384,7 +387,7 @@ function _su3site_se!(
384387
coords = ((row, col), (row, cp1), (rm1, cp1))
385388
# weights in the cluster
386389
wt_idxs = ((1, row, col), (2, row, cp1))
387-
wts, ϵ = apply_gatempo!(Ms, gs; trunc=alg.trscheme)
390+
wts, ϵ = apply_gatempo!(Ms, gs; trschemes)
388391
for (wt, wt_idx) in zip(wts, wt_idxs)
389392
peps.weights[CartesianIndex(wt_idx)] = wt / norm(wt, Inf)
390393
end
@@ -414,13 +417,19 @@ function su3site_iter(
414417
),
415418
)
416419
peps2 = deepcopy(peps)
420+
trscheme = alg.trscheme
417421
for i in 1:4
418422
for site in CartesianIndices(peps2.vertices)
419423
r, c = site[1], site[2]
420424
gs = gatempos[i][r, c]
421-
_su3site_se!(r, c, gs, peps2, alg)
425+
trschemes = [
426+
truncation_scheme(trscheme, 1, r, c)
427+
truncation_scheme(trscheme, 2, r, _next(c, size(peps2)[2]))
428+
]
429+
_su3site_se!(r, c, gs, peps2, trschemes)
422430
end
423431
peps2 = rotl90(peps2)
432+
trscheme = rotl90(trscheme)
424433
end
425434
return peps2
426435
end

src/algorithms/truncation/bond_truncation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ $(TYPEDFIELDS)
1313
1414
The truncation algorithm can be constructed from the following keyword arguments:
1515
16-
* `trscheme::TensorKit.TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond.
16+
* `trscheme::TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond.
1717
* `maxiter::Int=50` : Maximal number of ALS iterations.
1818
* `tol::Float64=1e-15` : ALS converges when fidelity change between two FET iterations is smaller than `tol`.
1919
* `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`.
2020
"""
2121
@kwdef struct ALSTruncation
22-
trscheme::TensorKit.TruncationScheme
22+
trscheme::TruncationScheme
2323
maxiter::Int = 50
2424
tol::Float64 = 1e-15
2525
check_interval::Int = 0

src/algorithms/truncation/fullenv_truncation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ $(TYPEDFIELDS)
1313
1414
The truncation algorithm can be constructed from the following keyword arguments:
1515
16-
* `trscheme::TensorKit.TruncationScheme` : SVD truncation scheme when optimizing the new bond matrix.
16+
* `trscheme::TruncationScheme` : SVD truncation scheme when optimizing the new bond matrix.
1717
* `maxiter::Int=50` : Maximal number of FET iterations.
1818
* `tol::Float64=1e-15` : FET converges when fidelity change between two FET iterations is smaller than `tol`.
1919
* `trunc_init::Bool=true` : Controls whether the initialization of the new bond matrix is obtained from truncated SVD of the old bond matrix.
@@ -24,7 +24,7 @@ The truncation algorithm can be constructed from the following keyword arguments
2424
* [Glen Evenbly, Phys. Rev. B 98, 085155 (2018)](@cite evenbly_gauge_2018).
2525
"""
2626
@kwdef struct FullEnvTruncation
27-
trscheme::TensorKit.TruncationScheme
27+
trscheme::TruncationScheme
2828
maxiter::Int = 50
2929
tol::Float64 = 1e-15
3030
trunc_init::Bool = true

src/algorithms/truncation/truncationschemes.jl

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ CTMRG specific truncation scheme for `tsvd` which keeps the bond space on which
55
is performed fixed. Since different environment directions and unit cell entries might
66
have different spaces, this truncation style is different from `TruncationSpace`.
77
"""
8-
struct FixedSpaceTruncation <: TensorKit.TruncationScheme end
8+
struct FixedSpaceTruncation <: TruncationScheme end
9+
10+
struct SiteDependentTruncation{T<:TruncationScheme} <: TruncationScheme
11+
trschemes::Array{T,3}
12+
end
913

1014
const TRUNCATION_SCHEME_SYMBOLS = IdDict{Symbol,Type{<:TruncationScheme}}(
1115
:fixedspace => FixedSpaceTruncation,
@@ -14,6 +18,7 @@ const TRUNCATION_SCHEME_SYMBOLS = IdDict{Symbol,Type{<:TruncationScheme}}(
1418
:truncdim => TensorKit.TruncationDimension,
1519
:truncspace => TensorKit.TruncationSpace,
1620
:truncbelow => TensorKit.TruncationCutoff,
21+
:sitedependent => SiteDependentTruncation,
1722
)
1823

1924
# Should be TruncationScheme but rename to avoid type piracy
@@ -25,3 +30,65 @@ function _TruncationScheme(; alg=Defaults.trscheme, η=nothing)
2530

2631
return isnothing(η) ? alg_type() : alg_type(η)
2732
end
33+
34+
function truncation_scheme(
35+
trscheme::TruncationScheme, direction::Int, row::Int, col::Int; kwargs...
36+
)
37+
return trscheme
38+
end
39+
40+
function truncation_scheme(
41+
trscheme::SiteDependentTruncation, direction::Int, row::Int, col::Int;
42+
)
43+
return trscheme.trschemes[direction, row, col]
44+
end
45+
46+
# Mirror a TruncationScheme by its anti-diagonal line.
47+
# When the number of directions is 2, it swaps the first and second direction, consistent with xbonds and ybonds, respectively.
48+
# When the number of directions is 4, it swaps the first and second, and third and fourth directions, consistent with the order NORTH, EAST, SOUTH, WEST.
49+
mirror_antidiag(trscheme::TruncationScheme) = trscheme
50+
function mirror_antidiag(trscheme::SiteDependentTruncation)
51+
directions = size(trscheme.trschemes)[1]
52+
if directions == 2
53+
trschemes_mirrored = stack(
54+
(
55+
mirror_antidiag(trscheme.trschemes[EAST, :, :]),
56+
mirror_antidiag(trscheme.trschemes[NORTH, :, :]),
57+
);
58+
dims=1,
59+
)
60+
elseif directions == 4
61+
trschemes_mirrored = stack((
62+
mirror_antidiag(trscheme.trschemes[EAST, :, :]),
63+
mirror_antidiag(trscheme.trschemes[NORTH, :, :]),
64+
mirror_antidiag(trscheme.trschemes[WEST, :, :]),
65+
mirror_antidiag(trscheme.trschemes[SOUTH, :, :]),
66+
))
67+
else
68+
error("Unsupported number of directions for mirror_antidiag: $directions")
69+
end
70+
return SiteDependentTruncation(trschemes_mirrored)
71+
end
72+
73+
# TODO: type piracy
74+
Base.rotl90(trscheme::TruncationScheme) = trscheme
75+
76+
function Base.rotl90(trscheme::SiteDependentTruncation)
77+
directions, rows, cols = size(trscheme.trschemes)
78+
trschemes_rotated = similar(trscheme.trschemes, directions, cols, rows)
79+
80+
if directions == 2
81+
trschemes_rotated[NORTH, :, :] = circshift(
82+
rotl90(trscheme.trschemes[EAST, :, :]), (0, -1)
83+
)
84+
trschemes_rotated[EAST, :, :] = rotl90(trscheme.trschemes[NORTH, :, :])
85+
elseif directions == 4
86+
for dir in 1:4
87+
dir′ = _prev(dir, 4)
88+
trschemes_rotated[dir′, :, :] = rotl90(trscheme.trschemes[dir, :, :])
89+
end
90+
else
91+
throw(ArgumentError("Unsupported number of directions for rotl90: $directions"))
92+
end
93+
return SiteDependentTruncation(trschemes_rotated)
94+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ end
6161
@time @safetestset "Cluster truncation with projectors" begin
6262
include("timeevol/cluster_projectors.jl")
6363
end
64+
@time @safetestset "Time evolution with site-dependent truncation" begin
65+
include("timeevol/sitedep_truncation.jl")
66+
end
6467
end
6568
if GROUP == "ALL" || GROUP == "UTILITY"
6669
@time @safetestset "LocalOperator" begin

test/timeevol/cluster_projectors.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Vspaces = [
3030
revs = [isdual(space(M, 1)) for M in Ms1[2:end]]
3131
# no truncation
3232
Ms2 = deepcopy(Ms1)
33-
wts2, ϵs, = _cluster_truncate!(Ms2, FixedSpaceTruncation(), revs)
33+
wts2, ϵs, = _cluster_truncate!(Ms2, fill(FixedSpaceTruncation(), N-1), revs)
3434
@test all((ϵ == 0) for ϵ in ϵs)
3535
absorb_wts_cluster!(Ms2, wts2)
3636
for (i, M) in enumerate(Ms2)
@@ -41,7 +41,7 @@ Vspaces = [
4141
@test all(lorths) && all(rorths)
4242
# truncation on one bond
4343
Ms3 = deepcopy(Ms1)
44-
wts3, ϵs, = _cluster_truncate!(Ms3, truncspace(Vns), revs)
44+
wts3, ϵs, = _cluster_truncate!(Ms3, fill(truncspace(Vns), N-1), revs)
4545
@test all((i == n) ||== 0) for (i, ϵ) in enumerate(ϵs))
4646
absorb_wts_cluster!(Ms3, wts3)
4747
for (i, M) in enumerate(Ms3)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Test
2+
using LinearAlgebra
3+
using Random
4+
using TensorKit
5+
using PEPSKit
6+
using PEPSKit: NORTH, EAST
7+
8+
function get_bonddims(wpeps::InfiniteWeightPEPS)
9+
xdims = collect(dim(domain(t, EAST)) for t in wpeps.vertices)
10+
ydims = collect(dim(domain(t, NORTH)) for t in wpeps.vertices)
11+
return stack([xdims, ydims]; dims=1)
12+
end
13+
14+
@testset "Simple update: bipartite 2-site" begin
15+
Nr, Nc = 2, 2
16+
ham = real(heisenberg_XYZ(InfiniteSquare(Nr, Nc); Jx=1.0, Jy=1.0, Jz=1.0))
17+
Random.seed!(100)
18+
wpeps0 = InfiniteWeightPEPS(rand, Float64, ℂ^2, ℂ^10; unitcell=(Nr, Nc))
19+
normalize!.(wpeps0.vertices, Inf)
20+
# set trscheme to be compatible with bipartite structure
21+
bonddims = stack([[6 4; 4 6], [5 7; 7 5]]; dims=1)
22+
trscheme = SiteDependentTruncation(collect(truncdim(d) for d in bonddims))
23+
alg = SimpleUpdate(1e-2, 1e-14, 4, trscheme)
24+
wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=true)
25+
@test get_bonddims(wpeps) == bonddims
26+
# check bipartite structure is preserved
27+
for col in 1:2
28+
cp1 = PEPSKit._next(col, 2)
29+
@test (
30+
wpeps.vertices[1, col] == wpeps.vertices[2, cp1] &&
31+
wpeps.weights[1, 1, col] == wpeps.weights[1, 2, cp1] &&
32+
wpeps.weights[2, 1, col] == wpeps.weights[2, 2, cp1]
33+
)
34+
end
35+
end
36+
37+
@testset "Simple update: generic 2-site and 3-site" begin
38+
Nr, Nc = 3, 4
39+
ham = real(heisenberg_XYZ(InfiniteSquare(Nr, Nc); Jx=1.0, Jy=1.0, Jz=1.0))
40+
Random.seed!(100)
41+
wpeps0 = InfiniteWeightPEPS(rand, Float64, ℂ^2, ℂ^10; unitcell=(Nr, Nc))
42+
normalize!.(wpeps0.vertices, Inf)
43+
# Site dependent truncation
44+
bonddims = rand(2:8, 2, Nr, Nc)
45+
@show bonddims
46+
trscheme = SiteDependentTruncation(collect(truncdim(d) for d in bonddims))
47+
alg = SimpleUpdate(1e-2, 1e-14, 2, trscheme)
48+
# 2-site SU
49+
wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=false)
50+
@test get_bonddims(wpeps) == bonddims
51+
# 3-site SU
52+
wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=false, force_3site=true)
53+
@test get_bonddims(wpeps) == bonddims
54+
end

0 commit comments

Comments
 (0)