Skip to content

Commit 01031f9

Browse files
committed
Iterative VUMPS and VOMPS
1 parent c8ce23a commit 01031f9

File tree

4 files changed

+275
-191
lines changed

4 files changed

+275
-191
lines changed

src/algorithms/approximate/vomps.jl

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,63 +5,105 @@ Base.@deprecate(approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:M
55
alg.verbosity, alg.alg_gauge, alg.alg_environments),
66
envs...; kwargs...))
77

8-
function approximate::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
9-
alg::VOMPS, envs=environments(ψ, toapprox))
10-
ϵ::Float64 = calc_galerkin(ψ, toapprox..., envs)
11-
temp_ACs = similar.(ψ.AC)
12-
scheduler = Defaults.scheduler[]
8+
function approximate(mps::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
9+
alg::VOMPS, envs=environments(mps, toapprox))
1310
log = IterLog("VOMPS")
14-
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
15-
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
11+
iter = 0
12+
ϵ = calc_galerkin(mps, toapprox..., envs)
13+
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
14+
recalculate!(envs, mps, toapprox...; alg_environments.tol)
1615

17-
LoggingExtras.withlevel(; alg.verbosity) do
16+
state = VOMPSState(mps, toapprox, envs, iter, ϵ)
17+
it = IterativeSolver(alg, state)
18+
19+
return LoggingExtras.withlevel(; alg.verbosity) do
1820
@infov 2 loginit!(log, ϵ)
19-
for iter in 1:(alg.maxiter)
20-
tmap!(eachcol(temp_ACs), 1:size(ψ, 2); scheduler) do col
21-
return _vomps_localupdate(col, ψ, toapprox, envs)
21+
22+
for (mps, envs, ϵ) in it
23+
if ϵ alg.tol
24+
@infov 2 logfinish!(log, it.iter, ϵ)
25+
return mps, envs, ϵ
26+
end
27+
if it.iter alg.maxiter
28+
@warnv 1 logcancel!(log, it.iter, ϵ)
29+
return mps, envs, ϵ
2230
end
31+
@infov 3 logiter!(log, it.iter, ϵ)
32+
end
2333

24-
alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
25-
ψ = MultilineMPS(temp_ACs, ψ.C[:, end]; alg_gauge.tol, alg_gauge.maxiter)
34+
# this should never be reached
35+
return it.state.mps, it.state.envs, it.state.ϵ
36+
end
37+
end
2638

27-
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
28-
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
39+
# need to specialize a bunch of functions because different arguments are passed with tuples
40+
# TODO: can we avoid this?
41+
function Base.iterate(it::IterativeSolver{<:VOMPS},
42+
state::VOMPSState{<:Any,<:Tuple}=it.state)
43+
ACs = localupdate_step!(it, state)
44+
mps = gauge_step!(it, state, ACs)
45+
envs = envs_step!(it, state, mps)
2946

30-
ψ, envs = alg.finalize(iter, ψ, toapprox, envs)::Tuple{typeof(ψ),typeof(envs)}
47+
# finalizer step
48+
mps, envs = it.finalize(state.iter, mps, state.operator, envs)::typeof((mps, envs))
3149

32-
ϵ = calc_galerkin(ψ, toapprox..., envs)
50+
# error criterion
51+
ϵ = calc_galerkin(mps, state.operator..., envs)
3352

34-
if ϵ <= alg.tol
35-
@infov 2 logfinish!(log, iter, ϵ)
36-
break
37-
end
38-
if iter == alg.maxiter
39-
@warnv 1 logcancel!(log, iter, ϵ)
40-
else
41-
@infov 3 logiter!(log, iter, ϵ)
42-
end
43-
end
53+
# update state
54+
it.state = VOMPSState(mps, state.operator, envs, state.iter + 1, ϵ)
55+
56+
return (mps, envs, ϵ), it.state
57+
end
58+
59+
# TODO: ac_proj and c_proj should be rewritten to also be simply ∂AC/∂C functions
60+
# once these have better support for different above/below mps
61+
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
62+
::SerialScheduler)
63+
alg_orth = QRpos()
64+
eachsite = 1:length(state.mps)
65+
ACs = similar(state.mps.AC)
66+
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
67+
68+
foreach(eachsite) do site
69+
AC = circshift([ac_proj(row, loc, state.mps, state.toapprox, state.envs)
70+
for row in 1:size(state.mps, 1)], 1)
71+
C = circshift([c_proj(row, loc, state.mps, state.toapprox, state.envs)
72+
for row in 1:size(state.mps, 1)], 1)
73+
dst_ACs[site] = regauge!(AC, C; alg=alg_orth)
74+
return nothing
4475
end
4576

46-
return ψ, envs, ϵ
77+
return ACs
4778
end
79+
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
80+
scheduler)
81+
alg_orth = QRpos()
82+
eachsite = 1:length(state.mps)
4883

49-
function _vomps_localupdate(loc, ψ, Oϕ, envs, factalg=QRpos())
50-
local tmp_AC, tmp_C
51-
if Defaults.scheduler[] isa SerialScheduler
52-
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
53-
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
54-
else
84+
ACs = similar(state.mps.AC)
85+
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
86+
87+
tforeach(eachsite; scheduler) do site
88+
local AC, C
5589
@sync begin
5690
Threads.@spawn begin
57-
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs)
58-
for row in 1:size(ψ, 1)], 1)
91+
AC = circshift([ac_proj(row, site, state.mps, state.operator, state.envs)
92+
for row in 1:size(state.mps, 1)], 1)
5993
end
6094
Threads.@spawn begin
61-
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)],
62-
1)
95+
C = circshift([c_proj(row, site, state.mps, state.operator, state.envs)
96+
for row in 1:size(state.mps, 1)], 1)
6397
end
6498
end
99+
dst_ACs[site] = regauge!(AC, C; alg=alg_orth)
100+
return nothing
65101
end
66-
return regauge!.(tmp_AC, tmp_C; alg=factalg)
102+
103+
return ACs
104+
end
105+
106+
function envs_step!(it::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple}, mps)
107+
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
108+
return recalculate!(state.envs, mps, state.operator...; alg_environments.tol)
67109
end

src/algorithms/groundstate/vumps.jl

Lines changed: 93 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -35,77 +35,124 @@ $(TYPEDFIELDS)
3535
finalize::F = Defaults._finalize
3636
end
3737

38-
function find_groundstate::InfiniteMPS, H, alg::VUMPS, envs=environments(ψ, H))
38+
# Internal state of the VUMPS algorithm
39+
struct VUMPSState{S,O,E}
40+
mps::S
41+
operator::O
42+
envs::E
43+
iter::Int
44+
ϵ::Float64
45+
which::Symbol
46+
end
47+
48+
function find_groundstate(mps::InfiniteMPS, operator, alg::VUMPS,
49+
envs=environments(mps, operator))
50+
return dominant_eigsolve(operator, mps, alg, envs; which=:SR)
51+
end
52+
53+
function dominant_eigsolve(operator, mps, alg::VUMPS, envs=environments(mps, operator);
54+
which)
3955
log = IterLog("VUMPS")
40-
ϵ::Float64 = calc_galerkin(ψ, H, ψ, envs)
41-
ACs = similar.(ψ.AC)
42-
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
43-
recalculate!(envs, ψ, H, ψ; alg_environments.tol)
56+
iter = 0
57+
ϵ = calc_galerkin(mps, operator, mps, envs)
58+
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
59+
recalculate!(envs, mps, operator, mps; alg_environments.tol)
4460

45-
state = (; ψ, H, envs, ACs, iter=0, ϵ)
61+
state = VUMPSState(mps, operator, envs, iter, ϵ, which)
4662
it = IterativeSolver(alg, state)
4763

4864
return LoggingExtras.withlevel(; alg.verbosity) do
49-
@infov 2 loginit!(log, ϵ, sum(expectation_value(ψ, H, envs)))
65+
@infov 2 loginit!(log, ϵ, sum(expectation_value(mps, operator, envs)))
5066

51-
for (ψ, envs, ϵ) in it
67+
for (mps, envs, ϵ) in it
5268
if ϵ alg.tol
53-
@infov 2 logfinish!(log, it.iter, ϵ, expectation_value(ψ, H, envs))
54-
return ψ, envs, ϵ
69+
@infov 2 logfinish!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
70+
return mps, envs, ϵ
5571
end
5672
if it.iter alg.maxiter
57-
@warnv 1 logcancel!(log, it.iter, ϵ, expectation_value(ψ, H, envs))
58-
return ψ, envs, ϵ
73+
@warnv 1 logcancel!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
74+
return mps, envs, ϵ
5975
end
60-
@infov 3 logiter!(log, it.iter, ϵ, expectation_value(ψ, H, envs))
76+
@infov 3 logiter!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
6177
end
6278

63-
return it.state.ψ, it.state.envs, it.state.ϵ
79+
# this should never be reached
80+
return it.state.mps, it.state.envs, it.state.ϵ
6481
end
6582
end
6683

6784
function Base.iterate(it::IterativeSolver{<:VUMPS}, state=it.state)
68-
# eigsolver step
69-
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
70-
scheduler = Defaults.scheduler[]
71-
ACs = tmap!(state.ACs, 1:length(state.ψ); scheduler) do site
72-
return _vumps_localupdate(site, state.ψ, state.H, state.envs, alg_eigsolve)
73-
end
74-
75-
# gauge step
76-
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
77-
ψ = InfiniteMPS(ACs, state.ψ.C[end]; alg_gauge.tol, alg_gauge.maxiter)
78-
79-
# environment step
80-
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
81-
envs = recalculate!(state.envs, ψ, state.H, ψ; alg_environments.tol)
85+
ACs = localupdate_step!(it, state)
86+
mps = gauge_step!(it, state, ACs)
87+
envs = envs_step!(it, state, mps)
8288

8389
# finalizer step
84-
ψ′, envs = it.finalize(state.iter, ψ, state.H, envs)::Tuple{typeof(ψ),typeof(envs)}
90+
mps, envs = it.finalize(state.iter, mps, state.operator, envs)::typeof((mps, envs))
8591

8692
# error criterion
87-
ϵ = calc_galerkin(ψ′, state.H, ψ′, envs)
93+
ϵ = calc_galerkin(mps, state.operator, mps, envs)
8894

8995
# update state
90-
it.state = (; ψ=ψ′, H=state.H, envs=envs′, ACs, iter=state.iter + 1, ϵ)
96+
it.state = VUMPSState(mps, state.operator, envs, state.iter + 1, ϵ, state.which)
9197

92-
return (ψ′, envs, ϵ), it.state
98+
return (mps, envs, ϵ), it.state
9399
end
94100

95-
function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
96-
local AC′, C′
97-
if Defaults.scheduler[] isa SerialScheduler
98-
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
99-
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
100-
else
101-
@sync begin
102-
Threads.@spawn begin
103-
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
104-
end
105-
Threads.@spawn begin
106-
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
107-
end
101+
function localupdate_step!(it::IterativeSolver{<:VUMPS}, state,
102+
scheduler=Defaults.scheduler[])
103+
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
104+
alg_orth = QRpos()
105+
106+
mps = state.mps
107+
eachsite = 1:length(mps)
108+
src_Cs = mps isa Multiline ? eachcol(mps.C) : mps.C
109+
src_ACs = mps isa Multiline ? eachcol(mps.AC) : mps.AC
110+
ACs = similar(mps.AC)
111+
dst_ACs = mps isa Multiline ? eachcol(ACs) : ACs
112+
113+
tforeach(eachsite, src_ACs, src_Cs; scheduler) do site, AC₀, C₀
114+
dst_ACs[site] = _localupdate_vumps_step!(site, mps, state.operator, state.envs,
115+
AC₀, C₀; parallel=false, alg_orth,
116+
state.which, alg_eigsolve)
117+
return nothing
118+
end
119+
120+
return ACs
121+
end
122+
123+
function _localupdate_vumps_step!(site, mps, operator, envs, AC₀, C₀;
124+
parallel::Bool=false, alg_orth=QRpos(),
125+
alg_eigsolve=Defaults.eigsolver, which)
126+
if !parallel
127+
_, AC = fixedpoint(∂∂AC(site, mps, operator, envs), AC₀, which, alg_eigsolve)
128+
_, C = fixedpoint(∂∂C(site, mps, operator, envs), C₀, which, alg_eigsolve)
129+
return regauge!(AC, C; alg=alg_orth)
130+
end
131+
132+
local AC, C
133+
@sync begin
134+
@spawn begin
135+
_, AC = fixedpoint(∂∂AC(site, mps, operator, envs),
136+
AC₀, which, alg_eigsolve)
137+
end
138+
@spawn begin
139+
_, C = fixedpoint(∂∂C(site, mps, operator, envs),
140+
C₀, which, alg_eigsolve)
108141
end
109142
end
110-
return regauge!(AC′, C′; alg=factalg)
143+
return regauge!(AC, C; alg=alg_orth)
144+
end
145+
146+
function gauge_step!(it::IterativeSolver{<:VUMPS}, state, ACs::AbstractVector)
147+
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
148+
return InfiniteMPS(ACs, state.mps.C[end]; alg_gauge.tol, alg_gauge.maxiter)
149+
end
150+
function gauge_step!(it::IterativeSolver{<:VUMPS}, state, ACs::AbstractMatrix)
151+
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
152+
return MultilineMPS(ACs, @view(state.mps.C[:, end]); alg_gauge.tol, alg_gauge.maxiter)
153+
end
154+
155+
function envs_step!(it::IterativeSolver{<:VUMPS}, state, mps)
156+
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
157+
return recalculate!(state.envs, mps, state.operator, mps; alg_environments.tol)
111158
end

0 commit comments

Comments
 (0)