Skip to content

Commit ceda9d2

Browse files
authored
Refactor VUMPS and VOMPS to avoid boxed variables (#265)
* Update OhMyThreads compat * Refactor VUMPS to avoid boxed variables * Loosen derivatives type restriction * Loosen ortho type restriction * Loosen IterativeSolver type restriction * Iterative `VUMPS` and `VOMPS` * Remove duplicate definition
1 parent 487bc7a commit ceda9d2

File tree

8 files changed

+315
-199
lines changed

8 files changed

+315
-199
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2626
Accessors = "0.1"
2727
Aqua = "0.8.9"
2828
BlockTensorKit = "0.1.4"
29-
Compat = "3.47, 4.10"
3029
Combinatorics = "1"
30+
Compat = "3.47, 4.10"
3131
DocStringExtensions = "0.9.3"
3232
HalfIntegers = "1.6.0"
3333
KrylovKit = "0.8.3, 0.9.2"
3434
LinearAlgebra = "1.6"
3535
LoggingExtras = "~1.0"
36-
OhMyThreads = "0.7.0"
36+
OhMyThreads = "0.7, 0.8"
3737
OptimKit = "0.3.1, 0.4"
3838
Pkg = "1"
3939
Plots = "1.40"

src/algorithms/approximate/vomps.jl

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,63 +5,104 @@ Base.@deprecate(approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:M
55
alg.verbosity, alg.alg_gauge, alg.alg_environments),
66
envs...; kwargs...))
77

8-
function approximate::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
9-
alg::VOMPS, envs=environments(ψ, toapprox))
10-
ϵ::Float64 = calc_galerkin(ψ, toapprox..., envs)
11-
temp_ACs = similar.(ψ.AC)
12-
scheduler = Defaults.scheduler[]
8+
function approximate(mps::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:MultilineMPS},
9+
alg::VOMPS, envs=environments(mps, toapprox))
1310
log = IterLog("VOMPS")
14-
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
15-
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
11+
iter = 0
12+
ϵ = calc_galerkin(mps, toapprox..., envs)
13+
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
14+
recalculate!(envs, mps, toapprox...; alg_environments.tol)
1615

17-
LoggingExtras.withlevel(; alg.verbosity) do
16+
state = VOMPSState(mps, toapprox, envs, iter, ϵ)
17+
it = IterativeSolver(alg, state)
18+
19+
return LoggingExtras.withlevel(; alg.verbosity) do
1820
@infov 2 loginit!(log, ϵ)
19-
for iter in 1:(alg.maxiter)
20-
tmap!(eachcol(temp_ACs), 1:size(ψ, 2); scheduler) do col
21-
return _vomps_localupdate(col, ψ, toapprox, envs)
21+
22+
for (mps, envs, ϵ) in it
23+
if ϵ alg.tol
24+
@infov 2 logfinish!(log, it.iter, ϵ)
25+
return mps, envs, ϵ
26+
end
27+
if it.iter alg.maxiter
28+
@warnv 1 logcancel!(log, it.iter, ϵ)
29+
return mps, envs, ϵ
2230
end
31+
@infov 3 logiter!(log, it.iter, ϵ)
32+
end
2333

24-
alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
25-
ψ = MultilineMPS(temp_ACs, ψ.C[:, end]; alg_gauge.tol, alg_gauge.maxiter)
34+
# this should never be reached
35+
return it.state.mps, it.state.envs, it.state.ϵ
36+
end
37+
end
2638

27-
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
28-
recalculate!(envs, ψ, toapprox...; alg_environments.tol)
39+
# need to specialize a bunch of functions because different arguments are passed with tuples
40+
# TODO: can we avoid this?
41+
function Base.iterate(it::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple})
42+
ACs = localupdate_step!(it, state)
43+
mps = gauge_step!(it, state, ACs)
44+
envs = envs_step!(it, state, mps)
2945

30-
ψ, envs = alg.finalize(iter, ψ, toapprox, envs)::Tuple{typeof(ψ),typeof(envs)}
46+
# finalizer step
47+
mps, envs = it.finalize(state.iter, mps, state.operator, envs)::typeof((mps, envs))
3148

32-
ϵ = calc_galerkin(ψ, toapprox..., envs)
49+
# error criterion
50+
ϵ = calc_galerkin(mps, state.operator..., envs)
3351

34-
if ϵ <= alg.tol
35-
@infov 2 logfinish!(log, iter, ϵ)
36-
break
37-
end
38-
if iter == alg.maxiter
39-
@warnv 1 logcancel!(log, iter, ϵ)
40-
else
41-
@infov 3 logiter!(log, iter, ϵ)
42-
end
43-
end
52+
# update state
53+
it.state = VOMPSState(mps, state.operator, envs, state.iter + 1, ϵ)
54+
55+
return (mps, envs, ϵ), it.state
56+
end
57+
58+
# TODO: ac_proj and c_proj should be rewritten to also be simply ∂AC/∂C functions
59+
# once these have better support for different above/below mps
60+
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
61+
::SerialScheduler)
62+
alg_orth = QRpos()
63+
eachsite = 1:length(state.mps)
64+
ACs = similar(state.mps.AC)
65+
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
66+
67+
foreach(eachsite) do site
68+
AC = circshift([ac_proj(row, loc, state.mps, state.toapprox, state.envs)
69+
for row in 1:size(state.mps, 1)], 1)
70+
C = circshift([c_proj(row, loc, state.mps, state.toapprox, state.envs)
71+
for row in 1:size(state.mps, 1)], 1)
72+
dst_ACs[site] = regauge!(AC, C; alg=alg_orth)
73+
return nothing
4474
end
4575

46-
return ψ, envs, ϵ
76+
return ACs
4777
end
78+
function localupdate_step!(::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple},
79+
scheduler)
80+
alg_orth = QRpos()
81+
eachsite = 1:length(state.mps)
4882

49-
function _vomps_localupdate(loc, ψ, Oϕ, envs, factalg=QRpos())
50-
local tmp_AC, tmp_C
51-
if Defaults.scheduler[] isa SerialScheduler
52-
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
53-
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)], 1)
54-
else
83+
ACs = similar(state.mps.AC)
84+
dst_ACs = state.mps isa Multiline ? eachcol(ACs) : ACs
85+
86+
tforeach(eachsite; scheduler) do site
87+
local AC, C
5588
@sync begin
5689
Threads.@spawn begin
57-
tmp_AC = circshift([ac_proj(row, loc, ψ, Oϕ, envs)
58-
for row in 1:size(ψ, 1)], 1)
90+
AC = circshift([ac_proj(row, site, state.mps, state.operator, state.envs)
91+
for row in 1:size(state.mps, 1)], 1)
5992
end
6093
Threads.@spawn begin
61-
tmp_C = circshift([c_proj(row, loc, ψ, Oϕ, envs) for row in 1:size(ψ, 1)],
62-
1)
94+
C = circshift([c_proj(row, site, state.mps, state.operator, state.envs)
95+
for row in 1:size(state.mps, 1)], 1)
6396
end
6497
end
98+
dst_ACs[site] = regauge!(AC, C; alg=alg_orth)
99+
return nothing
65100
end
66-
return regauge!.(tmp_AC, tmp_C; alg=factalg)
101+
102+
return ACs
103+
end
104+
105+
function envs_step!(it::IterativeSolver{<:VOMPS}, state::VOMPSState{<:Any,<:Tuple}, mps)
106+
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
107+
return recalculate!(state.envs, mps, state.operator...; alg_environments.tol)
67108
end

src/algorithms/derivatives.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function ∂AC(x::GenericMPSTensor{S,3}, operator::MPOTensor{S}, leftenv::MPSTen
8383
end
8484

8585
# mpo multiline
86-
function ∂AC(x::Vector, opp, leftenv, rightenv)
86+
function ∂AC(x::AbstractVector, opp, leftenv, rightenv)
8787
return circshift(map(∂AC, x, opp, leftenv, rightenv), 1)
8888
end
8989

@@ -109,7 +109,7 @@ function ∂AC2(x::AbstractTensorMap{<:Any,<:Any,3,3}, operator1::MPOTensor,
109109
operator2[7 -6; 4 5] * τ[5 -5; 2 3]
110110
end
111111

112-
function ∂AC2(x::Vector, opp1, opp2, leftenv, rightenv)
112+
function ∂AC2(x::AbstractVector, opp1, opp2, leftenv, rightenv)
113113
return circshift(map(∂AC2, x, opp1, opp2, leftenv, rightenv), 1)
114114
end
115115

@@ -122,7 +122,7 @@ function ∂C(x::MPSBondTensor, leftenv::MPSBondTensor, rightenv::MPSBondTensor)
122122
@plansor toret[-1; -2] := leftenv[-1; 1] * x[1; 2] * rightenv[2; -2]
123123
end
124124

125-
function ∂C(x::Vector, leftenv, rightenv)
125+
function ∂C(x::AbstractVector, leftenv, rightenv)
126126
return circshift(map(t -> ∂C(t...), zip(x, leftenv, rightenv)), 1)
127127
end
128128

src/algorithms/groundstate/vumps.jl

Lines changed: 106 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -35,63 +35,124 @@ $(TYPEDFIELDS)
3535
finalize::F = Defaults._finalize
3636
end
3737

38-
function find_groundstate::InfiniteMPS, H, alg::VUMPS, envs=environments(ψ, H))
39-
# initialization
40-
scheduler = Defaults.scheduler[]
41-
log = IterLog("VUMPS")
42-
ϵ::Float64 = calc_galerkin(ψ, H, ψ, envs)
43-
temp_ACs = similar.(ψ.AC)
44-
alg_environments = updatetol(alg.alg_environments, 0, ϵ)
45-
recalculate!(envs, ψ, H, ψ; alg_environments.tol)
46-
47-
LoggingExtras.withlevel(; alg.verbosity) do
48-
@infov 2 loginit!(log, ϵ, sum(expectation_value(ψ, H, envs)))
49-
for iter in 1:(alg.maxiter)
50-
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
51-
tmap!(temp_ACs, 1:length(ψ); scheduler) do loc
52-
return _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
53-
end
38+
# Internal state of the VUMPS algorithm
39+
struct VUMPSState{S,O,E}
40+
mps::S
41+
operator::O
42+
envs::E
43+
iter::Int
44+
ϵ::Float64
45+
which::Symbol
46+
end
5447

55-
alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
56-
ψ = InfiniteMPS(temp_ACs, ψ.C[end]; alg_gauge.tol, alg_gauge.maxiter)
48+
function find_groundstate(mps::InfiniteMPS, operator, alg::VUMPS,
49+
envs=environments(mps, operator))
50+
return dominant_eigsolve(operator, mps, alg, envs; which=:SR)
51+
end
5752

58-
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
59-
recalculate!(envs, ψ, H, ψ; alg_environments.tol)
53+
function dominant_eigsolve(operator, mps, alg::VUMPS, envs=environments(mps, operator);
54+
which)
55+
log = IterLog("VUMPS")
56+
iter = 0
57+
ϵ = calc_galerkin(mps, operator, mps, envs)
58+
alg_environments = updatetol(alg.alg_environments, iter, ϵ)
59+
recalculate!(envs, mps, operator, mps; alg_environments.tol)
6060

61-
ψ, envs = alg.finalize(iter, ψ, H, envs)::Tuple{typeof(ψ),typeof(envs)}
61+
state = VUMPSState(mps, operator, envs, iter, ϵ, which)
62+
it = IterativeSolver(alg, state)
6263

63-
ϵ = calc_galerkin(ψ, H, ψ, envs)
64+
return LoggingExtras.withlevel(; alg.verbosity) do
65+
@infov 2 loginit!(log, ϵ, sum(expectation_value(mps, operator, envs)))
6466

65-
# breaking conditions
66-
if ϵ <= alg.tol
67-
@infov 2 logfinish!(log, iter, ϵ, expectation_value(ψ, H, envs))
68-
break
67+
for (mps, envs, ϵ) in it
68+
if ϵ alg.tol
69+
@infov 2 logfinish!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
70+
return mps, envs, ϵ
6971
end
70-
if iter == alg.maxiter
71-
@warnv 1 logcancel!(log, iter, ϵ, expectation_value(ψ, H, envs))
72-
else
73-
@infov 3 logiter!(log, iter, ϵ, expectation_value(ψ, H, envs))
72+
if it.iter alg.maxiter
73+
@warnv 1 logcancel!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
74+
return mps, envs, ϵ
7475
end
76+
@infov 3 logiter!(log, it.iter, ϵ, expectation_value(mps, operator, envs))
7577
end
78+
79+
# this should never be reached
80+
return it.state.mps, it.state.envs, it.state.ϵ
7681
end
82+
end
7783

78-
return ψ, envs, ϵ
84+
function Base.iterate(it::IterativeSolver{<:VUMPS}, state=it.state)
85+
ACs = localupdate_step!(it, state)
86+
mps = gauge_step!(it, state, ACs)
87+
envs = envs_step!(it, state, mps)
88+
89+
# finalizer step
90+
mps, envs = it.finalize(state.iter, mps, state.operator, envs)::typeof((mps, envs))
91+
92+
# error criterion
93+
ϵ = calc_galerkin(mps, state.operator, mps, envs)
94+
95+
# update state
96+
it.state = VUMPSState(mps, state.operator, envs, state.iter + 1, ϵ, state.which)
97+
98+
return (mps, envs, ϵ), it.state
7999
end
80100

81-
function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
82-
local AC′, C′
83-
if Defaults.scheduler[] isa SerialScheduler
84-
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
85-
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
86-
else
87-
@sync begin
88-
Threads.@spawn begin
89-
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
90-
end
91-
Threads.@spawn begin
92-
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
93-
end
101+
function localupdate_step!(it::IterativeSolver{<:VUMPS}, state,
102+
scheduler=Defaults.scheduler[])
103+
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
104+
alg_orth = QRpos()
105+
106+
mps = state.mps
107+
eachsite = 1:length(mps)
108+
src_Cs = mps isa Multiline ? eachcol(mps.C) : mps.C
109+
src_ACs = mps isa Multiline ? eachcol(mps.AC) : mps.AC
110+
ACs = similar(mps.AC)
111+
dst_ACs = mps isa Multiline ? eachcol(ACs) : ACs
112+
113+
tforeach(eachsite, src_ACs, src_Cs; scheduler) do site, AC₀, C₀
114+
dst_ACs[site] = _localupdate_vumps_step!(site, mps, state.operator, state.envs,
115+
AC₀, C₀; parallel=false, alg_orth,
116+
state.which, alg_eigsolve)
117+
return nothing
118+
end
119+
120+
return ACs
121+
end
122+
123+
function _localupdate_vumps_step!(site, mps, operator, envs, AC₀, C₀;
124+
parallel::Bool=false, alg_orth=QRpos(),
125+
alg_eigsolve=Defaults.eigsolver, which)
126+
if !parallel
127+
_, AC = fixedpoint(∂∂AC(site, mps, operator, envs), AC₀, which, alg_eigsolve)
128+
_, C = fixedpoint(∂∂C(site, mps, operator, envs), C₀, which, alg_eigsolve)
129+
return regauge!(AC, C; alg=alg_orth)
130+
end
131+
132+
local AC, C
133+
@sync begin
134+
@spawn begin
135+
_, AC = fixedpoint(∂∂AC(site, mps, operator, envs),
136+
AC₀, which, alg_eigsolve)
137+
end
138+
@spawn begin
139+
_, C = fixedpoint(∂∂C(site, mps, operator, envs),
140+
C₀, which, alg_eigsolve)
94141
end
95142
end
96-
return regauge!(AC′, C′; alg=factalg)
143+
return regauge!(AC, C; alg=alg_orth)
144+
end
145+
146+
function gauge_step!(it::IterativeSolver{<:VUMPS}, state, ACs::AbstractVector)
147+
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
148+
return InfiniteMPS(ACs, state.mps.C[end]; alg_gauge.tol, alg_gauge.maxiter)
149+
end
150+
function gauge_step!(it::IterativeSolver{<:VUMPS}, state, ACs::AbstractMatrix)
151+
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
152+
return MultilineMPS(ACs, @view(state.mps.C[:, end]); alg_gauge.tol, alg_gauge.maxiter)
153+
end
154+
155+
function envs_step!(it::IterativeSolver{<:VUMPS}, state, mps)
156+
alg_environments = updatetol(it.alg_environments, state.iter, state.ϵ)
157+
return recalculate!(state.envs, mps, state.operator, mps; alg_environments.tol)
97158
end

0 commit comments

Comments
 (0)