Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
88590b0
add anisotropic Simple Update
sanderdemeyer Jun 6, 2025
e2eeb66
fix 3site su
sanderdemeyer Jun 6, 2025
996f9ba
Merge branch 'master' into su_anisotropic
sanderdemeyer Jun 16, 2025
308f3fb
make SiteDependentTruncation a new truncation scheme
sanderdemeyer Jun 16, 2025
72734de
Merge branch 'master' into su_anisotropic
sanderdemeyer Jun 16, 2025
58221d4
Merge branch 'master' into su_anisotropic
sanderdemeyer Jun 17, 2025
213b4f7
update on arguments and constructors
sanderdemeyer Jun 17, 2025
436f8cc
Fix 3-site SU for SiteDependentTruncationScheme
sanderdemeyer Jun 17, 2025
12ff1bd
small documentation update
sanderdemeyer Jun 17, 2025
b92df22
Change name and add rotl90
sanderdemeyer Jun 17, 2025
ff5930e
fix export
sanderdemeyer Jun 17, 2025
578b54e
change arguments in test
sanderdemeyer Jun 17, 2025
e06f568
add test for 3-site cluster update
sanderdemeyer Jun 18, 2025
076b3c8
change test to increase patch coverage
sanderdemeyer Jun 18, 2025
e515976
change name again
sanderdemeyer Jun 18, 2025
83c5ac8
Merge branch 'QuantumKitHub:master' into su_anisotropic
sanderdemeyer Jun 18, 2025
0e405f0
small updates
sanderdemeyer Jun 19, 2025
53d0ac6
Merge branch 'master' into su_anisotropic
sanderdemeyer Jun 19, 2025
acf19cb
rewrite mirror_antidiag
sanderdemeyer Jun 19, 2025
d1d1298
add test on non-square unit cell and fix bug
sanderdemeyer Jun 20, 2025
721cd2b
make the new test a separate test to test bipartite=true
sanderdemeyer Jun 20, 2025
ad21266
Restore old Heisenberg SU-AD test
Yue-Zhengyuan Jun 20, 2025
0ff0f7d
Add test of SU with SiteDependentTruncation
Yue-Zhengyuan Jun 20, 2025
136ec18
Add TODO on bipartite check for SiteDependentTruncation
Yue-Zhengyuan Jun 20, 2025
84a277f
Merge branch 'master' into su_anisotropic
sanderdemeyer Jun 24, 2025
5b98cc9
Update src/algorithms/time_evolution/simpleupdate3site.jl
sanderdemeyer Jun 24, 2025
ec2897c
remove redundant TensorKit. before TruncationScheme
sanderdemeyer Jun 24, 2025
0db17a0
remove redundant TensorKit. before TruncationScheme bis
sanderdemeyer Jun 24, 2025
1707d3c
remove redundant where clauses
lkdvos Jun 24, 2025
6854da9
Add selection support with symbol
lkdvos Jun 24, 2025
4eee139
Change back constant
lkdvos Jun 24, 2025
615fa7f
Also simplify rotation
lkdvos Jun 24, 2025
e70759a
Remove test on Hubbard SU with SiteDependentTruncation
Yue-Zhengyuan Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/PEPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ using Compat
using Accessors: @set, @reset
using VectorInterface
import VectorInterface as VI
using TensorKit, KrylovKit, OptimKit, TensorOperations

using TensorKit
using TensorKit: TruncationScheme

using KrylovKit, OptimKit, TensorOperations
using ChainRulesCore, Zygote
using LoggingExtras

Expand Down Expand Up @@ -83,7 +87,8 @@ using .Defaults: set_scheduler!
export set_scheduler!
export SVDAdjoint, FullSVDReverseRule, IterSVD
export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG
export FixedSpaceTruncation, HalfInfiniteProjector, FullInfiniteProjector
export FixedSpaceTruncation, SiteDependentTruncation
export HalfInfiniteProjector, FullInfiniteProjector
export LocalOperator, physicalspace
export expectation_value, cost_function, product_peps, correlation_length, network_value
export correlator
Expand Down
14 changes: 9 additions & 5 deletions src/algorithms/time_evolution/simpleupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct SimpleUpdate
dt::Number
tol::Float64
maxiter::Int
trscheme::TensorKit.TruncationScheme
trscheme::TruncationScheme
end
# TODO: add kwarg constructor and SU Defaults

Expand All @@ -34,7 +34,7 @@ function _su_xbond!(
col::Int,
gate::AbstractTensorMap{T,S,2,2},
peps::InfiniteWeightPEPS,
alg::SimpleUpdate,
trscheme::TruncationScheme,
) where {T<:Number,S<:ElementarySpace}
Nr, Nc = size(peps)
@assert 1 <= row <= Nr && 1 <= col <= Nc
Expand All @@ -47,7 +47,7 @@ function _su_xbond!(
B = _absorb_weights(B, peps.weights, row, cp1, Tuple(1:4), sqrtsB, false)
# apply gate
X, a, b, Y = _qr_bond(A, B)
a, s, b, ϵ = _apply_gate(a, b, gate, alg.trscheme)
a, s, b, ϵ = _apply_gate(a, b, gate, trscheme)
A, B = _qr_bond_undo(X, a, b, Y)
# remove environment weights
_allfalse = ntuple(Returns(false), 3)
Expand Down Expand Up @@ -86,6 +86,9 @@ function su_iter(
# to update them using code for x-weights
if direction == 2
peps2 = mirror_antidiag(peps2)
trscheme = mirror_antidiag(alg.trscheme)
else
trscheme = alg.trscheme
end
if bipartite
for r in 1:2
Expand All @@ -94,7 +97,7 @@ function su_iter(
direction == 1 ? gate : gate_mirrored,
(CartesianIndex(r, 1), CartesianIndex(r, 2)),
)
ϵ = _su_xbond!(r, 1, term, peps2, alg)
ϵ = _su_xbond!(r, 1, term, peps2, truncation_scheme(trscheme, 1, r, 1))
peps2.vertices[rp1, 2] = deepcopy(peps2.vertices[r, 1])
peps2.vertices[rp1, 1] = deepcopy(peps2.vertices[r, 2])
peps2.weights[1, rp1, 2] = deepcopy(peps2.weights[1, r, 1])
Expand All @@ -106,7 +109,7 @@ function su_iter(
direction == 1 ? gate : gate_mirrored,
(CartesianIndex(r, c), CartesianIndex(r, c + 1)),
)
ϵ = _su_xbond!(r, c, term, peps2, alg)
ϵ = _su_xbond!(r, c, term, peps2, truncation_scheme(trscheme, 1, r, c))
end
end
if direction == 2
Expand Down Expand Up @@ -185,6 +188,7 @@ function simpleupdate(
nnonly = is_nearest_neighbour(ham)
use_3site = force_3site || !nnonly
@assert !(bipartite && use_3site) "3-site simple update is incompatible with bipartite lattice."
# TODO: check SiteDependentTruncation is compatible with bipartite structure
if use_3site
return _simpleupdate3site(peps, ham, alg; check_interval)
else
Expand Down
39 changes: 24 additions & 15 deletions src/algorithms/time_evolution/simpleupdate3site.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ The arrows between `Pa`, `s`, `Pb` are
function _proj_from_RL(
r::AbstractTensorMap{T,S,1,1},
l::AbstractTensorMap{T,S,1,1};
trunc::TensorKit.TruncationScheme=notrunc(),
trunc::TruncationScheme=notrunc(),
rev::Bool=false,
) where {T<:Number,S<:ElementarySpace}
rl = r * l
Expand All @@ -235,16 +235,19 @@ end
Given a cluster `Ms` and the pre-calculated `R`, `L` bond matrices,
find all projectors `Pa`, `Pb` and Schmidt weights `wts` on internal bonds.
"""
function _get_allprojs(Ms, Rs, Ls, trunc::TensorKit.TruncationScheme, revs::Vector{Bool})
function _get_allprojs(
Ms, Rs, Ls, trschemes::Vector{E}, revs::Vector{Bool}
) where {E<:TruncationScheme}
N = length(Ms)
@assert length(trschemes) == N - 1
projs_errs = map(1:(N - 1)) do i
trunc2 = if isa(trunc, FixedSpaceTruncation)
trunc = if isa(trschemes[i], FixedSpaceTruncation)
V = space(Ms[i + 1], 1)
truncspace(isdual(V) ? V' : V)
else
trunc
trschemes[i]
end
return _proj_from_RL(Rs[i], Ls[i]; trunc=trunc2, rev=revs[i])
return _proj_from_RL(Rs[i], Ls[i]; trunc, rev=revs[i])
end
Pas = map(Base.Fix2(getindex, 1), projs_errs)
wts = map(Base.Fix2(getindex, 2), projs_errs)
Expand All @@ -258,10 +261,10 @@ end
Find projectors to truncate internal bonds of the cluster `Ms`
"""
function _cluster_truncate!(
Ms::Vector{T}, trunc::TensorKit.TruncationScheme, revs::Vector{Bool}
) where {T<:PEPSTensor}
Ms::Vector{T}, trschemes::Vector{E}, revs::Vector{Bool}
) where {T<:PEPSTensor,E<:TruncationScheme}
Rs, Ls = _get_allRLs(Ms)
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, trunc, revs)
Pas, Pbs, wts, ϵs = _get_allprojs(Ms, Rs, Ls, trschemes, revs)
# apply projectors
# M1 -- (Pa1,wt1,Pb1) -- M2 -- (Pa2,wt2,Pb2) -- M3
for (i, (Pa, Pb)) in enumerate(zip(Pas, Pbs))
Expand Down Expand Up @@ -322,13 +325,13 @@ In the cluster, the axes of each PEPSTensor are reordered as
```
"""
function apply_gatempo!(
Ms::Vector{T1}, gs::Vector{T2}; trunc::TensorKit.TruncationScheme
) where {T1<:PEPSTensor,T2<:AbstractTensorMap}
Ms::Vector{T1}, gs::Vector{T2}; trschemes::Vector{E}
) where {T1<:PEPSTensor,T2<:AbstractTensorMap,E<:TruncationScheme}
@assert length(Ms) == length(gs)
revs = [isdual(space(M, 1)) for M in Ms[2:end]]
@assert !all(revs)
_apply_gatempo!(Ms, gs)
wts, ϵs, = _cluster_truncate!(Ms, trunc, revs)
wts, ϵs, = _cluster_truncate!(Ms, trschemes, revs)
return wts, ϵs
end

Expand Down Expand Up @@ -373,8 +376,8 @@ function get_3site_se(peps::InfiniteWeightPEPS, row::Int, col::Int)
end

function _su3site_se!(
row::Int, col::Int, gs::Vector{T}, peps::InfiniteWeightPEPS, alg::SimpleUpdate
) where {T<:AbstractTensorMap}
row::Int, col::Int, gs::Vector{T}, peps::InfiniteWeightPEPS, trschemes::Vector{E}
) where {T<:AbstractTensorMap,E<:TruncationScheme}
Nr, Nc = size(peps)
@assert 1 <= row <= Nr && 1 <= col <= Nc
rm1, cp1 = _prev(row, Nr), _next(col, Nc)
Expand All @@ -384,7 +387,7 @@ function _su3site_se!(
coords = ((row, col), (row, cp1), (rm1, cp1))
# weights in the cluster
wt_idxs = ((1, row, col), (2, row, cp1))
wts, ϵ = apply_gatempo!(Ms, gs; trunc=alg.trscheme)
wts, ϵ = apply_gatempo!(Ms, gs; trschemes)
for (wt, wt_idx) in zip(wts, wt_idxs)
peps.weights[CartesianIndex(wt_idx)] = wt / norm(wt, Inf)
end
Expand Down Expand Up @@ -414,13 +417,19 @@ function su3site_iter(
),
)
peps2 = deepcopy(peps)
trscheme = alg.trscheme
for i in 1:4
for site in CartesianIndices(peps2.vertices)
r, c = site[1], site[2]
gs = gatempos[i][r, c]
_su3site_se!(r, c, gs, peps2, alg)
trschemes = [
truncation_scheme(trscheme, 1, r, c)
truncation_scheme(trscheme, 2, r, _next(c, size(peps2)[2]))
]
_su3site_se!(r, c, gs, peps2, trschemes)
end
peps2 = rotl90(peps2)
trscheme = rotl90(trscheme)
end
return peps2
end
Expand Down
4 changes: 2 additions & 2 deletions src/algorithms/truncation/bond_truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ $(TYPEDFIELDS)

The truncation algorithm can be constructed from the following keyword arguments:

* `trscheme::TensorKit.TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond.
* `trscheme::TruncationScheme`: SVD truncation scheme when initilizing the truncated tensors connected by the bond.
* `maxiter::Int=50` : Maximal number of ALS iterations.
* `tol::Float64=1e-15` : ALS converges when fidelity change between two FET iterations is smaller than `tol`.
* `check_interval::Int=0` : Set number of iterations to print information. Output is suppressed when `check_interval <= 0`.
"""
@kwdef struct ALSTruncation
trscheme::TensorKit.TruncationScheme
trscheme::TruncationScheme
maxiter::Int = 50
tol::Float64 = 1e-15
check_interval::Int = 0
Expand Down
4 changes: 2 additions & 2 deletions src/algorithms/truncation/fullenv_truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ $(TYPEDFIELDS)

The truncation algorithm can be constructed from the following keyword arguments:

* `trscheme::TensorKit.TruncationScheme` : SVD truncation scheme when optimizing the new bond matrix.
* `trscheme::TruncationScheme` : SVD truncation scheme when optimizing the new bond matrix.
* `maxiter::Int=50` : Maximal number of FET iterations.
* `tol::Float64=1e-15` : FET converges when fidelity change between two FET iterations is smaller than `tol`.
* `trunc_init::Bool=true` : Controls whether the initialization of the new bond matrix is obtained from truncated SVD of the old bond matrix.
Expand All @@ -24,7 +24,7 @@ The truncation algorithm can be constructed from the following keyword arguments
* [Glen Evenbly, Phys. Rev. B 98, 085155 (2018)](@cite evenbly_gauge_2018).
"""
@kwdef struct FullEnvTruncation
trscheme::TensorKit.TruncationScheme
trscheme::TruncationScheme
maxiter::Int = 50
tol::Float64 = 1e-15
trunc_init::Bool = true
Expand Down
69 changes: 68 additions & 1 deletion src/algorithms/truncation/truncationschemes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
is performed fixed. Since different environment directions and unit cell entries might
have different spaces, this truncation style is different from `TruncationSpace`.
"""
struct FixedSpaceTruncation <: TensorKit.TruncationScheme end
struct FixedSpaceTruncation <: TruncationScheme end

struct SiteDependentTruncation{T<:TruncationScheme} <: TruncationScheme
trschemes::Array{T,3}
end

const TRUNCATION_SCHEME_SYMBOLS = IdDict{Symbol,Type{<:TruncationScheme}}(
:fixedspace => FixedSpaceTruncation,
Expand All @@ -14,6 +18,7 @@
:truncdim => TensorKit.TruncationDimension,
:truncspace => TensorKit.TruncationSpace,
:truncbelow => TensorKit.TruncationCutoff,
:sitedependent => SiteDependentTruncation,
)

# Should be TruncationScheme but rename to avoid type piracy
Expand All @@ -25,3 +30,65 @@

return isnothing(η) ? alg_type() : alg_type(η)
end

function truncation_scheme(
trscheme::TruncationScheme, direction::Int, row::Int, col::Int; kwargs...
)
return trscheme
end

function truncation_scheme(
trscheme::SiteDependentTruncation, direction::Int, row::Int, col::Int;
)
return trscheme.trschemes[direction, row, col]
end

# Mirror a TruncationScheme by its anti-diagonal line.
# When the number of directions is 2, it swaps the first and second direction, consistent with xbonds and ybonds, respectively.
# When the number of directions is 4, it swaps the first and second, and third and fourth directions, consistent with the order NORTH, EAST, SOUTH, WEST.
mirror_antidiag(trscheme::TruncationScheme) = trscheme
function mirror_antidiag(trscheme::SiteDependentTruncation)
directions = size(trscheme.trschemes)[1]
if directions == 2
trschemes_mirrored = stack(
(
mirror_antidiag(trscheme.trschemes[EAST, :, :]),
mirror_antidiag(trscheme.trschemes[NORTH, :, :]),
);
dims=1,
)
elseif directions == 4
trschemes_mirrored = stack((

Check warning on line 61 in src/algorithms/truncation/truncationschemes.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/truncation/truncationschemes.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
mirror_antidiag(trscheme.trschemes[EAST, :, :]),
mirror_antidiag(trscheme.trschemes[NORTH, :, :]),
mirror_antidiag(trscheme.trschemes[WEST, :, :]),
mirror_antidiag(trscheme.trschemes[SOUTH, :, :]),
))
else
error("Unsupported number of directions for mirror_antidiag: $directions")

Check warning on line 68 in src/algorithms/truncation/truncationschemes.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/truncation/truncationschemes.jl#L68

Added line #L68 was not covered by tests
end
return SiteDependentTruncation(trschemes_mirrored)
end

# TODO: type piracy
Base.rotl90(trscheme::TruncationScheme) = trscheme

function Base.rotl90(trscheme::SiteDependentTruncation)
directions, rows, cols = size(trscheme.trschemes)
trschemes_rotated = similar(trscheme.trschemes, directions, cols, rows)

if directions == 2
trschemes_rotated[NORTH, :, :] = circshift(
rotl90(trscheme.trschemes[EAST, :, :]), (0, -1)
)
trschemes_rotated[EAST, :, :] = rotl90(trscheme.trschemes[NORTH, :, :])
elseif directions == 4
for dir in 1:4
dir′ = _prev(dir, 4)
trschemes_rotated[dir′, :, :] = rotl90(trscheme.trschemes[dir, :, :])
end

Check warning on line 89 in src/algorithms/truncation/truncationschemes.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/truncation/truncationschemes.jl#L85-L89

Added lines #L85 - L89 were not covered by tests
else
throw(ArgumentError("Unsupported number of directions for rotl90: $directions"))

Check warning on line 91 in src/algorithms/truncation/truncationschemes.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/truncation/truncationschemes.jl#L91

Added line #L91 was not covered by tests
end
return SiteDependentTruncation(trschemes_rotated)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ end
@time @safetestset "Cluster truncation with projectors" begin
include("timeevol/cluster_projectors.jl")
end
@time @safetestset "Time evolution with site-dependent truncation" begin
include("timeevol/sitedep_truncation.jl")
end
end
if GROUP == "ALL" || GROUP == "UTILITY"
@time @safetestset "LocalOperator" begin
Expand Down
4 changes: 2 additions & 2 deletions test/timeevol/cluster_projectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Vspaces = [
revs = [isdual(space(M, 1)) for M in Ms1[2:end]]
# no truncation
Ms2 = deepcopy(Ms1)
wts2, ϵs, = _cluster_truncate!(Ms2, FixedSpaceTruncation(), revs)
wts2, ϵs, = _cluster_truncate!(Ms2, fill(FixedSpaceTruncation(), N-1), revs)
@test all((ϵ == 0) for ϵ in ϵs)
absorb_wts_cluster!(Ms2, wts2)
for (i, M) in enumerate(Ms2)
Expand All @@ -41,7 +41,7 @@ Vspaces = [
@test all(lorths) && all(rorths)
# truncation on one bond
Ms3 = deepcopy(Ms1)
wts3, ϵs, = _cluster_truncate!(Ms3, truncspace(Vns), revs)
wts3, ϵs, = _cluster_truncate!(Ms3, fill(truncspace(Vns), N-1), revs)
@test all((i == n) || (ϵ == 0) for (i, ϵ) in enumerate(ϵs))
absorb_wts_cluster!(Ms3, wts3)
for (i, M) in enumerate(Ms3)
Expand Down
54 changes: 54 additions & 0 deletions test/timeevol/sitedep_truncation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using Test
using LinearAlgebra
using Random
using TensorKit
using PEPSKit
using PEPSKit: NORTH, EAST

function get_bonddims(wpeps::InfiniteWeightPEPS)
xdims = collect(dim(domain(t, EAST)) for t in wpeps.vertices)
ydims = collect(dim(domain(t, NORTH)) for t in wpeps.vertices)
return stack([xdims, ydims]; dims=1)
end

@testset "Simple update: bipartite 2-site" begin
Nr, Nc = 2, 2
ham = real(heisenberg_XYZ(InfiniteSquare(Nr, Nc); Jx=1.0, Jy=1.0, Jz=1.0))
Random.seed!(100)
wpeps0 = InfiniteWeightPEPS(rand, Float64, ℂ^2, ℂ^10; unitcell=(Nr, Nc))
normalize!.(wpeps0.vertices, Inf)
# set trscheme to be compatible with bipartite structure
bonddims = stack([[6 4; 4 6], [5 7; 7 5]]; dims=1)
trscheme = SiteDependentTruncation(collect(truncdim(d) for d in bonddims))
alg = SimpleUpdate(1e-2, 1e-14, 4, trscheme)
wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=true)
@test get_bonddims(wpeps) == bonddims
# check bipartite structure is preserved
for col in 1:2
cp1 = PEPSKit._next(col, 2)
@test (
wpeps.vertices[1, col] == wpeps.vertices[2, cp1] &&
wpeps.weights[1, 1, col] == wpeps.weights[1, 2, cp1] &&
wpeps.weights[2, 1, col] == wpeps.weights[2, 2, cp1]
)
end
end

@testset "Simple update: generic 2-site and 3-site" begin
Nr, Nc = 3, 4
ham = real(heisenberg_XYZ(InfiniteSquare(Nr, Nc); Jx=1.0, Jy=1.0, Jz=1.0))
Random.seed!(100)
wpeps0 = InfiniteWeightPEPS(rand, Float64, ℂ^2, ℂ^10; unitcell=(Nr, Nc))
normalize!.(wpeps0.vertices, Inf)
# Site dependent truncation
bonddims = rand(2:8, 2, Nr, Nc)
@show bonddims
trscheme = SiteDependentTruncation(collect(truncdim(d) for d in bonddims))
alg = SimpleUpdate(1e-2, 1e-14, 2, trscheme)
# 2-site SU
wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=false)
@test get_bonddims(wpeps) == bonddims
# 3-site SU
wpeps, = simpleupdate(wpeps0, ham, alg; bipartite=false, force_3site=true)
@test get_bonddims(wpeps) == bonddims
end