Skip to content

Commit d86b5fa

Browse files
Gertian and lkdvos
authored
Parallelization through OhMyThreads.jl (#219)
* redid parallel_GD in new MPSKit * Defaults -> MPSKit.Defaults * fixed. SOme more things can be made parallel but then I get errors... * Workign version * formatting * switch out threading preferences for OhMyThreads * update VUMPS update statmech vumps * update VOMPS * update TDVP * update QP * update Grassmann * remove unused code * Add missing scheduler * restrict GradientGrassmann linesearch maxiter * update docs on multithreading [skip ci] * replace `tmap` with `tmap!` * increase maxiter on linesearch * Fix nasty race condition * bugfix diagonal C (again) * Add `check_recalculate!` for MultipleEnvironments * improve test robustness and performance --------- Co-authored-by: Lukas Devos <[email protected]>
1 parent 24c1630 commit d86b5fa

File tree

18 files changed

+212
-233
lines changed

18 files changed

+212
-233
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
1313
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
16+
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
1617
OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0"
17-
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1818
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1919
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -35,10 +35,10 @@ HalfIntegers = "1.6.0"
3535
KrylovKit = "0.8.3"
3636
LinearAlgebra = "1.6"
3737
LoggingExtras = "~1.0"
38+
OhMyThreads = "0.7.0"
3839
OptimKit = "0.3.1"
3940
Pkg = "1"
4041
Plots = "1.40"
41-
Preferences = "1"
4242
Printf = "1"
4343
Random = "1"
4444
RecipesBase = "1.1"

docs/src/man/parallelism.md

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,31 +81,22 @@ BLAS: libmkl_rt.so
8181

8282
## MPSKit multithreading
8383

84-
Within MPSKit, when Julia is started with multiple threads, by default the `Threads.@spawn`
85-
machinery will be used to parallelize the code as much as possible. In particular, there are
86-
three main places where this is happening, which can be disabled separately through a preference-based system.
84+
Within MPSKit, when Julia is started with multiple threads, by default the `OhMyThreads.jl`
85+
machinery will be used to parallelize the code as much as possible. In particular, this mostly
86+
occurs whenever there is a unitcell and local updates can take place at each site in parallel.
8787

88-
1. During the process of some algorithms (e.g. VUMPS), local updates can take place at each
89-
site in parallel. This can be controlled by the `parallelize_sites` preference.
90-
91-
2. During the calculation of the environments, when the MPO is block-sparse, it is possible
92-
to parallelize over these blocks. This can be enabled or disabled by the
93-
`parallelize_transfers` preference. (Note that left- and right environments will always
94-
be computed in parallel)
95-
96-
3. During the calculation of the derivatives, when the MPO is block-sparse, it is possible
97-
to parallelize over these blocks. This can be enabled or disabled by the
98-
`parallelize_derivatives` preference.
99-
100-
For convenience, these preferences can be set via [`MPSKit.Defaults.set_parallelization`](@ref), which takes as input pairs of preferences and booleans. For example, to disable all parallelization, one can call
88+
The multithreading behaviour can be controlled through a global `scheduler`, which can be set
89+
using the `MPSKit.Defaults.set_scheduler!(arg; kwargs...)` function. This function accepts
90+
either a `Symbol`, an `OhMyThreads.Scheduler` or keywords to determine a scheduler automatically.
10191

10292
```julia
103-
Defaults.set_parallelization("sites" => false, "transfers" => false, "derivatives" => false)
93+
MPSKit.Defaults.set_scheduler!(:serial) # disable multithreading
94+
MPSKit.Defaults.set_scheduler!(:greedy) # multithreading with greedy load-balancing
95+
MPSKit.Defaults.set_scheduler!(:dynamic) # default: multithreading with some load-balancing
10496
```
10597

106-
!!! warning
107-
These settings are statically set at compile-time, and for changes to take
108-
effect the Julia session must be restarted.
98+
For further reference on the available schedulers and finer control, please refer to the
99+
[`OhMyThreads.jl` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/)
109100

110101
## TensorKit multithreading
111102

src/MPSKit.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using LinearAlgebra: LinearAlgebra
1717
using Random
1818
using Base: @kwdef
1919
using LoggingExtras
20+
using OhMyThreads
2021

2122
# bells and whistles for mpses
2223
export InfiniteMPS, FiniteMPS, WindowMPS, MultilineMPS
@@ -159,4 +160,9 @@ include("algorithms/ED.jl")
159160

160161
include("algorithms/unionalg.jl")
161162

163+
function __init__()
164+
Defaults.set_scheduler!()
165+
return nothing
166+
end
167+
162168
end

src/algorithms/approximate/vomps.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,14 @@ function approximate(ψ::MultilineMPS, toapprox::Tuple{<:MultilineMPO,<:Multilin
2020
alg::VOMPS, envs=environments(ψ, toapprox))
2121
ϵ::Float64 = calc_galerkin(ψ, envs)
2222
temp_ACs = similar.(ψ.AC)
23+
scheduler = Defaults.scheduler[]
2324
log = IterLog("VOMPS")
2425

2526
LoggingExtras.withlevel(; alg.verbosity) do
2627
@infov 2 loginit!(log, ϵ)
2728
for iter in 1:(alg.maxiter)
28-
@static if Defaults.parallelize_sites
29-
@sync for col in 1:size(ψ, 2)
30-
Threads.@spawn begin
31-
temp_ACs[:, col] = _vomps_localupdate(col, ψ, toapprox, envs)
32-
end
33-
end
34-
else
35-
for col in 1:size(ψ, 2)
36-
temp_ACs[:, col] = _vomps_localupdate(col, ψ, toapprox, envs)
37-
end
29+
tmap!(eachcol(temp_ACs), 1:size(ψ, 2); scheduler) do col
30+
return _vomps_localupdate(col, ψ, toapprox, envs)
3831
end
3932

4033
alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
@@ -64,7 +57,10 @@ end
6457

6558
function _vomps_localupdate(loc, ψ, (O, ψ₀), envs, factalg=QRpos())
6659
local tmp_AC, tmp_C
67-
@static if Defaults.parallelize_sites
60+
if Defaults.scheduler[] isa SerialScheduler
61+
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
62+
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
63+
else
6864
@sync begin
6965
Threads.@spawn begin
7066
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
@@ -73,9 +69,6 @@ function _vomps_localupdate(loc, ψ, (O, ψ₀), envs, factalg=QRpos())
7369
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
7470
end
7571
end
76-
else
77-
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
78-
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
7972
end
8073
return regauge!.(tmp_AC, tmp_C; alg=factalg)
8174
end

src/algorithms/changebonds/svdcut.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function changebonds(ψ::InfiniteMPS, alg::SvdCut)
9191
ψ = if space(ncr, 1) != space(copied[1], 1)
9292
InfiniteMPS(copied)
9393
else
94-
C₀ = TensorMap(complex(ncr))
94+
C₀ = ncr isa TensorMap ? complex(ncr) : TensorMap(complex(ncr))
9595
InfiniteMPS(copied, C₀)
9696
end
9797
return normalize!(ψ)

src/algorithms/derivatives.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,6 @@ function ∂AC2(x::Vector, opp1, opp2, leftenv, rightenv)
105105
return circshift(map(∂AC2, x, opp1, opp2, leftenv, rightenv), 1)
106106
end
107107

108-
"""
109-
Zero-site derivative (the C matrix to the right of pos)
110-
"""
111-
function ∂C(x::MPSBondTensor, leftenv::AbstractVector, rightenv::AbstractVector)::typeof(x)
112-
if Defaults.parallelize_derivatives
113-
@floop WorkStealingEx() for (le, re) in zip(leftenv, rightenv)
114-
t = ∂C(x, le, re)
115-
@reduce(y = inplace_add!(nothing, t))
116-
end
117-
else
118-
y = ∂C(x, leftenv[1], rightenv[1])
119-
for (le, re) in zip(leftenv[2:end], rightenv[2:end])
120-
VectorInterface.add!(y, ∂C(x, le, re))
121-
end
122-
end
123-
124-
return y
125-
end
126-
127108
function ∂C(x::MPSBondTensor, leftenv::MPSTensor, rightenv::MPSTensor)
128109
@plansor y[-1; -2] := leftenv[-1 3; 1] * x[1; 2] * rightenv[2 3; -2]
129110
return y isa BlockTensorMap ? only(y) : y

src/algorithms/excitation/quasiparticleexcitation.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -247,19 +247,9 @@ function effective_excitation_hamiltonian(H::Union{InfiniteMPOHamiltonian,
247247
envs.leftenvs,
248248
envs.rightenvs))
249249
ϕ′ = similar(ϕ)
250-
@static if Defaults.parallelize_sites
251-
@sync for loc in 1:length(ϕ)
252-
Threads.@spawn begin
253-
ϕ′[loc] = _effective_excitation_local_apply(loc, ϕ, H, energy[loc],
254-
envs)
255-
end
256-
end
257-
else
258-
for loc in 1:length(ϕ)
259-
ϕ′[loc] = _effective_excitation_local_apply(loc, ϕ, H, energy[loc], envs)
260-
end
250+
tforeach(1:length(ϕ); scheduler=Defaults.scheduler[]) do loc
251+
return ϕ′[loc] = _effective_excitation_local_apply(loc, ϕ, H, energy[loc], envs)
261252
end
262-
263253
return ϕ′
264254
end
265255

src/algorithms/grassmann.jl

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ module GrassmannMPS
1212

1313
using ..MPSKit
1414
using TensorKit
15+
using OhMyThreads
1516
import TensorKitManifolds.Grassmann
1617

1718
function TensorKit.rmul!(a::Grassmann.GrassmannTangent, b::AbstractTensorMap)
@@ -67,23 +68,38 @@ struct ManifoldPoint{T,E,G,C}
6768
Rhoreg::C # the regularized density matrices
6869
end
6970

70-
function ManifoldPoint(state::Union{InfiniteMPS,FiniteMPS}, envs)
71+
function ManifoldPoint(state::FiniteMPS, envs)
7172
al_d = similar(state.AL)
7273
O = envs.operator
7374
for i in 1:length(state)
7475
al_d[i] = MPSKit.∂∂AC(i, state, O, envs) * state.AC[i]
7576
end
76-
7777
g = Grassmann.project.(al_d, state.AL)
7878

7979
Rhoreg = Vector{eltype(state.C)}(undef, length(state))
8080
δmin = sqrt(eps(real(scalartype(state))))
81-
for i in 1:length(state)
82-
Rhoreg[i] = regularize(state.C[i], max(norm(g[i]) / 10, δmin))
81+
tmap!(Rhoreg, 1:length(state); scheduler=MPSKit.Defaults.scheduler[]) do i
82+
return regularize(state.C[i], max(norm(g[i]) / 10, δmin))
8383
end
8484

8585
return ManifoldPoint(state, envs, g, Rhoreg)
8686
end
87+
function ManifoldPoint(state::InfiniteMPS, envs)
88+
δmin = sqrt(eps(real(scalartype(state))))
89+
Tg = Core.Compiler.return_type(Grassmann.project,
90+
Tuple{eltype(state.AL),eltype(state.AL)})
91+
g = similar(state.AL, Tg)
92+
ρ = similar(state.C)
93+
94+
MPSKit.check_recalculate!(envs, state)
95+
tforeach(1:length(state); scheduler=MPSKit.Defaults.scheduler[]) do i
96+
AC′ = MPSKit.∂∂AC(i, state, envs.operator, envs) * state.AC[i]
97+
g[i] = Grassmann.project(AC′, state.AL[i])
98+
ρ[i] = regularize(state.C[i], max(norm(g[i]) / 10, δmin))
99+
return nothing
100+
end
101+
return ManifoldPoint(state, envs, g, ρ)
102+
end
87103

88104
function ManifoldPoint(state::MultilineMPS, envs)
89105
# FIXME: add support for unitcells
@@ -115,10 +131,10 @@ cell as tangent vectors on Grassmann manifolds.
115131
"""
116132
function fg(x::ManifoldPoint{T}) where {T<:Union{InfiniteMPS,FiniteMPS}}
117133
# the gradient I want to return is the preconditioned gradient!
118-
g_prec = Vector{PrecGrad{eltype(x.g),eltype(x.Rhoreg)}}(undef, length(x.g))
119-
120-
for i in 1:length(x.state)
121-
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.C[i]'), x.Rhoreg[i])
134+
Tg = Core.Compiler.return_type(PrecGrad, Tuple{eltype(x.g),eltype(x.Rhoreg)})
135+
g_prec = similar(x.g, Tg)
136+
tmap!(g_prec, eachindex(x.g); scheduler=MPSKit.Defaults.scheduler[]) do i
137+
return PrecGrad(rmul!(copy(x.g[i]), x.state.C[i]'), x.Rhoreg[i])
122138
end
123139

124140
# TODO: the operator really should not be part of the environments, and this should
@@ -151,10 +167,10 @@ function retract(x::ManifoldPoint{<:MultilineMPS}, tg, alpha)
151167
g = reshape(tg, size(x.state))
152168

153169
nal = similar(x.state.AL)
154-
h = similar(g)
155-
for (i, cg) in enumerate(tg)
156-
(nal[i], th) = Grassmann.retract(x.state.AL[i], cg.Pg, alpha)
157-
h[i] = PrecGrad(th)
170+
h = similar(tg)
171+
tmap!(h, eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
172+
nal[i], th = Grassmann.retract(x.state.AL[i], g[i].Pg, alpha)
173+
return PrecGrad(th)
158174
end
159175

160176
nstate = MPSKit.MultilineMPS(nal, x.state.C[:, end])
@@ -171,9 +187,10 @@ function retract(x::ManifoldPoint{<:InfiniteMPS}, g, alpha)
171187
envs = x.envs
172188
nal = similar(state.AL)
173189
h = similar(g) # The tangent at the end-point
174-
for i in 1:length(g)
190+
191+
tmap!(h, eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
175192
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
176-
h[i] = PrecGrad(th)
193+
return PrecGrad(th)
177194
end
178195

179196
nstate = InfiniteMPS(nal, state.C[end])
@@ -209,9 +226,9 @@ Transport a tangent vector `h` along the retraction from `x` in direction `g` by
209226
`alpha`. `xp` is the end-point of the retraction.
210227
"""
211228
function transport!(h, x, g, alpha, xp)
212-
for i in 1:length(h)
213-
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
214-
xp.state.AL[i]))
229+
tforeach(1:length(h); scheduler=MPSKit.Defaults.scheduler[]) do i
230+
return h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
231+
xp.state.AL[i]))
215232
end
216233
return h
217234
end

src/algorithms/groundstate/gradient_grassmann.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ struct GradientGrassmann{O<:OptimKit.OptimizationAlgorithm,F} <: Algorithm
2828
finalize!::F
2929

3030
function GradientGrassmann(; method=ConjugateGradient, (finalize!)=OptimKit._finalize!,
31-
tol=Defaults.tol, maxiter=Defaults.maxiter, verbosity=2)
31+
tol=Defaults.tol, maxiter=Defaults.maxiter,
32+
verbosity=Defaults.verbosity - 1)
3233
if isa(method, OptimKit.OptimizationAlgorithm)
3334
# We were given an optimisation method, just use it.
3435
m = method
3536
elseif method <: OptimKit.OptimizationAlgorithm
3637
# We were given an optimisation method type, construct an instance of it.
37-
m = method(; maxiter=maxiter, verbosity=verbosity, gradtol=tol)
38+
# restrict linesearch maxiter
39+
linesearch = OptimKit.HagerZhangLineSearch(; verbosity=verbosity - 2,
40+
maxiter=100)
41+
m = method(; maxiter, verbosity, gradtol=tol, linesearch)
3842
else
3943
msg = "method should be either an instance or a subtype of `OptimKit.OptimizationAlgorithm`."
4044
throw(ArgumentError(msg))

src/algorithms/groundstate/vumps.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,15 @@ function find_groundstate(ψ::InfiniteMPS, H, alg::VUMPS, envs=environments(ψ,
3030
# initialization
3131
ϵ::Float64 = calc_galerkin(ψ, envs)
3232
temp_ACs = similar.(ψ.AC)
33+
scheduler = Defaults.scheduler[]
3334
log = IterLog("VUMPS")
3435

3536
LoggingExtras.withlevel(; alg.verbosity) do
3637
@infov 2 loginit!(log, ϵ, sum(expectation_value(ψ, H, envs)))
3738
for iter in 1:(alg.maxiter)
3839
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
39-
@static if Defaults.parallelize_sites
40-
@sync for loc in 1:length(ψ)
41-
Threads.@spawn begin
42-
temp_ACs[loc] = _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
43-
end
44-
end
45-
else
46-
for loc in 1:length(ψ)
47-
temp_ACs[loc] = _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
48-
end
40+
tmap!(temp_ACs, 1:length(ψ); scheduler) do loc
41+
return _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
4942
end
5043

5144
alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
@@ -76,7 +69,10 @@ end
7669

7770
function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
7871
local AC′, C′
79-
@static if Defaults.parallelize_sites
72+
if Defaults.scheduler[] isa SerialScheduler
73+
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
74+
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
75+
else
8076
@sync begin
8177
Threads.@spawn begin
8278
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
@@ -85,9 +81,6 @@ function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
8581
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
8682
end
8783
end
88-
else
89-
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
90-
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
9184
end
9285
return regauge!(AC′, C′; alg=factalg)
9386
end

0 commit comments

Comments
 (0)