4 changes: 2 additions & 2 deletions Project.toml
@@ -13,8 +13,8 @@ HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -35,10 +35,10 @@ HalfIntegers = "1.6.0"
KrylovKit = "0.8.3"
LinearAlgebra = "1.6"
LoggingExtras = "~1.0"
OhMyThreads = "0.7.0"
OptimKit = "0.3.1"
Pkg = "1"
Plots = "1.40"
Preferences = "1"
Printf = "1"
Random = "1"
RecipesBase = "1.1"
31 changes: 11 additions & 20 deletions docs/src/man/parallelism.md
@@ -81,31 +81,22 @@ BLAS: libmkl_rt.so

## MPSKit multithreading

Within MPSKit, when Julia is started with multiple threads, by default the `Threads.@spawn`
machinery will be used to parallelize the code as much as possible. In particular, there are
three main places where this is happening, which can be disabled separately through a preference-based system.
Within MPSKit, when Julia is started with multiple threads, the `OhMyThreads.jl`
machinery will be used by default to parallelize the code as much as possible. In particular, this mostly
occurs whenever there is a unit cell and local updates can take place at each site in parallel.

1. During the process of some algorithms (e.g. VUMPS), local updates can take place at each
site in parallel. This can be controlled by the `parallelize_sites` preference.

2. During the calculation of the environments, when the MPO is block-sparse, it is possible
to parallelize over these blocks. This can be enabled or disabled by the
`parallelize_transfers` preference. (Note that left- and right environments will always
be computed in parallel)

3. During the calculation of the derivatives, when the MPO is block-sparse, it is possible
to parallelize over these blocks. This can be enabled or disabled by the
`parallelize_derivatives` preference.

For convenience, these preferences can be set via [`MPSKit.Defaults.set_parallelization`](@ref), which takes as input pairs of preferences and booleans. For example, to disable all parallelization, one can call
The multithreading behaviour can be controlled through a global `scheduler`, which can be set
using the `MPSKit.Defaults.set_scheduler!(arg; kwargs...)` function. This function accepts
either a `Symbol`, an `OhMyThreads.Scheduler`, or keywords from which a scheduler is determined automatically.

```julia
Defaults.set_parallelization("sites" => false, "transfers" => false, "derivatives" => false)
MPSKit.Defaults.set_scheduler!(:serial) # disable multithreading
MPSKit.Defaults.set_scheduler!(:greedy) # multithreading with greedy load-balancing
MPSKit.Defaults.set_scheduler!(:dynamic) # default: multithreading with some load-balancing
```

!!! warning
These settings are statically set at compile-time, and for changes to take
effect the Julia session must be restarted.
For further reference on the available schedulers and finer control, please refer to the
[`OhMyThreads.jl` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/).
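
As a concrete illustration, a hedged sketch of the new interface (the scheduler types and their keywords come from `OhMyThreads.jl`, and exact keyword names may differ between versions):

```julia
using MPSKit, OhMyThreads

# symbol-based selection, as above
MPSKit.Defaults.set_scheduler!(:greedy)

# or pass an OhMyThreads scheduler object directly;
# `ntasks` is an illustrative tuning knob, not a required setting
MPSKit.Defaults.set_scheduler!(GreedyScheduler(; ntasks=Threads.nthreads()))
```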

## TensorKit multithreading

6 changes: 6 additions & 0 deletions src/MPSKit.jl
@@ -17,6 +17,7 @@ using LinearAlgebra: LinearAlgebra
using Random
using Base: @kwdef
using LoggingExtras
using OhMyThreads

# bells and whistles for mpses
export InfiniteMPS, FiniteMPS, WindowMPS, MultilineMPS
@@ -159,4 +160,9 @@ include("algorithms/ED.jl")

include("algorithms/unionalg.jl")

function __init__()
Defaults.set_scheduler!()
return nothing
end

end
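
The new `__init__` hook above initializes the scheduler once at package load time. A minimal sketch of the `Defaults` plumbing this diff implies — the `Ref`-based storage and the symbol mapping are assumptions for illustration, not necessarily MPSKit's actual implementation:

```julia
module Defaults

using OhMyThreads

# global slot, read elsewhere in the code as `Defaults.scheduler[]`
const scheduler = Ref{OhMyThreads.Scheduler}(SerialScheduler())

function set_scheduler!(sched=:dynamic; kwargs...)
    scheduler[] = if sched isa OhMyThreads.Scheduler
        sched
    elseif sched === :serial
        SerialScheduler()
    elseif sched === :greedy
        GreedyScheduler(; kwargs...)
    elseif sched === :dynamic
        DynamicScheduler(; kwargs...)
    else
        throw(ArgumentError("unknown scheduler: $sched"))
    end
    return scheduler[]
end

end # module Defaults
```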
21 changes: 7 additions & 14 deletions src/algorithms/approximate/vomps.jl
@@ -20,21 +20,14 @@
alg::VOMPS, envs=environments(ψ, toapprox))
ϵ::Float64 = calc_galerkin(ψ, envs)
temp_ACs = similar.(ψ.AC)
scheduler = Defaults.scheduler[]
log = IterLog("VOMPS")

LoggingExtras.withlevel(; alg.verbosity) do
@infov 2 loginit!(log, ϵ)
for iter in 1:(alg.maxiter)
@static if Defaults.parallelize_sites
@sync for col in 1:size(ψ, 2)
Threads.@spawn begin
temp_ACs[:, col] = _vomps_localupdate(col, ψ, toapprox, envs)
end
end
else
for col in 1:size(ψ, 2)
temp_ACs[:, col] = _vomps_localupdate(col, ψ, toapprox, envs)
end
tmap!(eachcol(temp_ACs), 1:size(ψ, 2); scheduler) do col
return _vomps_localupdate(col, ψ, toapprox, envs)
end

alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
@@ -64,7 +57,10 @@

function _vomps_localupdate(loc, ψ, (O, ψ₀), envs, factalg=QRpos())
local tmp_AC, tmp_C
@static if Defaults.parallelize_sites
if Defaults.scheduler[] isa SerialScheduler
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)

Codecov / codecov/patch — check warning on line 62: added lines src/algorithms/approximate/vomps.jl#L61-L62 were not covered by tests.
else
@sync begin
Threads.@spawn begin
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
@@ -73,9 +69,6 @@
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
end
end
else
tmp_AC = circshift([ac_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
tmp_C = circshift([c_proj(row, loc, ψ, envs) for row in 1:size(ψ, 1)], 1)
end
return regauge!.(tmp_AC, tmp_C; alg=factalg)
end
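
The pattern adopted here (and in the other algorithm files below) replaces hand-rolled `@sync`/`Threads.@spawn` loops with `OhMyThreads.tmap!`, which maps a function over an index range and writes the results into a preallocated output. A self-contained toy version of the call shape used above, with a stand-in for the per-column update:

```julia
using OhMyThreads

# stand-in for an expensive local update at column `col`
expensive_update(col) = sum(abs2, randn(100)) + col

cols = 1:8
out = Vector{Float64}(undef, length(cols))

scheduler = DynamicScheduler()  # or SerialScheduler() to run single-threaded
tmap!(out, cols; scheduler) do col
    return expensive_update(col)
end
```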
2 changes: 1 addition & 1 deletion src/algorithms/changebonds/svdcut.jl
@@ -91,7 +91,7 @@ function changebonds(ψ::InfiniteMPS, alg::SvdCut)
ψ = if space(ncr, 1) != space(copied[1], 1)
InfiniteMPS(copied)
else
C₀ = TensorMap(complex(ncr))
C₀ = ncr isa TensorMap ? complex(ncr) : TensorMap(complex(ncr))
InfiniteMPS(copied, C₀)
end
return normalize!(ψ)
19 changes: 0 additions & 19 deletions src/algorithms/derivatives.jl
@@ -105,25 +105,6 @@ function ∂AC2(x::Vector, opp1, opp2, leftenv, rightenv)
return circshift(map(∂AC2, x, opp1, opp2, leftenv, rightenv), 1)
end

"""
Zero-site derivative (the C matrix to the right of pos)
"""
function ∂C(x::MPSBondTensor, leftenv::AbstractVector, rightenv::AbstractVector)::typeof(x)
if Defaults.parallelize_derivatives
@floop WorkStealingEx() for (le, re) in zip(leftenv, rightenv)
t = ∂C(x, le, re)
@reduce(y = inplace_add!(nothing, t))
end
else
y = ∂C(x, leftenv[1], rightenv[1])
for (le, re) in zip(leftenv[2:end], rightenv[2:end])
VectorInterface.add!(y, ∂C(x, le, re))
end
end

return y
end

function ∂C(x::MPSBondTensor, leftenv::MPSTensor, rightenv::MPSTensor)
@plansor y[-1; -2] := leftenv[-1 3; 1] * x[1; 2] * rightenv[2 3; -2]
return y isa BlockTensorMap ? only(y) : y
14 changes: 2 additions & 12 deletions src/algorithms/excitation/quasiparticleexcitation.jl
@@ -247,19 +247,9 @@ function effective_excitation_hamiltonian(H::Union{InfiniteMPOHamiltonian,
envs.leftenvs,
envs.rightenvs))
ϕ′ = similar(ϕ)
@static if Defaults.parallelize_sites
@sync for loc in 1:length(ϕ)
Threads.@spawn begin
ϕ′[loc] = _effective_excitation_local_apply(loc, ϕ, H, energy[loc],
envs)
end
end
else
for loc in 1:length(ϕ)
ϕ′[loc] = _effective_excitation_local_apply(loc, ϕ, H, energy[loc], envs)
end
tforeach(1:length(ϕ); scheduler=Defaults.scheduler[]) do loc
return ϕ′[loc] = _effective_excitation_local_apply(loc, ϕ, H, energy[loc], envs)
end

return ϕ′
end

51 changes: 34 additions & 17 deletions src/algorithms/grassmann.jl
@@ -12,6 +12,7 @@ module GrassmannMPS

using ..MPSKit
using TensorKit
using OhMyThreads
import TensorKitManifolds.Grassmann

function TensorKit.rmul!(a::Grassmann.GrassmannTangent, b::AbstractTensorMap)
@@ -67,23 +68,38 @@ struct ManifoldPoint{T,E,G,C}
Rhoreg::C # the regularized density matrices
end

function ManifoldPoint(state::Union{InfiniteMPS,FiniteMPS}, envs)
function ManifoldPoint(state::FiniteMPS, envs)
al_d = similar(state.AL)
O = envs.operator
for i in 1:length(state)
al_d[i] = MPSKit.∂∂AC(i, state, O, envs) * state.AC[i]
end

g = Grassmann.project.(al_d, state.AL)

Rhoreg = Vector{eltype(state.C)}(undef, length(state))
δmin = sqrt(eps(real(scalartype(state))))
for i in 1:length(state)
Rhoreg[i] = regularize(state.C[i], max(norm(g[i]) / 10, δmin))
tmap!(Rhoreg, 1:length(state); scheduler=MPSKit.Defaults.scheduler[]) do i
return regularize(state.C[i], max(norm(g[i]) / 10, δmin))
end

return ManifoldPoint(state, envs, g, Rhoreg)
end
function ManifoldPoint(state::InfiniteMPS, envs)
δmin = sqrt(eps(real(scalartype(state))))
Tg = Core.Compiler.return_type(Grassmann.project,
Tuple{eltype(state.AL),eltype(state.AL)})
g = similar(state.AL, Tg)
ρ = similar(state.C)

MPSKit.check_recalculate!(envs, state)
tforeach(1:length(state); scheduler=MPSKit.Defaults.scheduler[]) do i
AC′ = MPSKit.∂∂AC(i, state, envs.operator, envs) * state.AC[i]
g[i] = Grassmann.project(AC′, state.AL[i])
ρ[i] = regularize(state.C[i], max(norm(g[i]) / 10, δmin))
return nothing
end
return ManifoldPoint(state, envs, g, ρ)
end

function ManifoldPoint(state::MultilineMPS, envs)
# FIXME: add support for unitcells
@@ -115,10 +131,10 @@ cell as tangent vectors on Grassmann manifolds.
"""
function fg(x::ManifoldPoint{T}) where {T<:Union{InfiniteMPS,FiniteMPS}}
# the gradient I want to return is the preconditioned gradient!
g_prec = Vector{PrecGrad{eltype(x.g),eltype(x.Rhoreg)}}(undef, length(x.g))

for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.C[i]'), x.Rhoreg[i])
Tg = Core.Compiler.return_type(PrecGrad, Tuple{eltype(x.g),eltype(x.Rhoreg)})
g_prec = similar(x.g, Tg)
tmap!(g_prec, eachindex(x.g); scheduler=MPSKit.Defaults.scheduler[]) do i
return PrecGrad(rmul!(copy(x.g[i]), x.state.C[i]'), x.Rhoreg[i])
end

# TODO: the operator really should not be part of the environments, and this should
@@ -151,10 +167,10 @@ function retract(x::ManifoldPoint{<:MultilineMPS}, tg, alpha)
g = reshape(tg, size(x.state))

nal = similar(x.state.AL)
h = similar(g)
for (i, cg) in enumerate(tg)
(nal[i], th) = Grassmann.retract(x.state.AL[i], cg.Pg, alpha)
h[i] = PrecGrad(th)
h = similar(tg)
tmap!(h, eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
nal[i], th = Grassmann.retract(x.state.AL[i], g[i].Pg, alpha)
return PrecGrad(th)
end

nstate = MPSKit.MultilineMPS(nal, x.state.C[:, end])
@@ -171,9 +187,10 @@ function retract(x::ManifoldPoint{<:InfiniteMPS}, g, alpha)
envs = x.envs
nal = similar(state.AL)
h = similar(g) # The tangent at the end-point
for i in 1:length(g)

tmap!(h, eachindex(g); scheduler=MPSKit.Defaults.scheduler[]) do i
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
return PrecGrad(th)
end

nstate = InfiniteMPS(nal, state.C[end])
Expand Down Expand Up @@ -209,9 +226,9 @@ Transport a tangent vector `h` along the retraction from `x` in direction `g` by
`alpha`. `xp` is the end-point of the retraction.
"""
function transport!(h, x, g, alpha, xp)
for i in 1:length(h)
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
tforeach(1:length(h); scheduler=MPSKit.Defaults.scheduler[]) do i
return h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
end
return h
end
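
A recurring idiom in the rewritten Grassmann code is allocating a typed output container up front, via `Core.Compiler.return_type`, so that `tmap!` can fill it in place. A small standalone illustration of that idiom, independent of MPSKit:

```julia
# infer the result type of `f` applied to eltype(xs) without calling `f`
f(x) = 2x
xs = [1.0, 2.0, 3.0]
T = Core.Compiler.return_type(f, Tuple{eltype(xs)})  # Float64
out = similar(xs, T)  # typed container, ready to be filled in parallel
```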
8 changes: 6 additions & 2 deletions src/algorithms/groundstate/gradient_grassmann.jl
@@ -28,13 +28,17 @@ struct GradientGrassmann{O<:OptimKit.OptimizationAlgorithm,F} <: Algorithm
finalize!::F

function GradientGrassmann(; method=ConjugateGradient, (finalize!)=OptimKit._finalize!,
tol=Defaults.tol, maxiter=Defaults.maxiter, verbosity=2)
tol=Defaults.tol, maxiter=Defaults.maxiter,
verbosity=Defaults.verbosity - 1)
if isa(method, OptimKit.OptimizationAlgorithm)
# We were given an optimisation method, just use it.
m = method
elseif method <: OptimKit.OptimizationAlgorithm
# We were given an optimisation method type, construct an instance of it.
m = method(; maxiter=maxiter, verbosity=verbosity, gradtol=tol)
# restrict linesearch maxiter
linesearch = OptimKit.HagerZhangLineSearch(; verbosity=verbosity - 2,
maxiter=100)
m = method(; maxiter, verbosity, gradtol=tol, linesearch)
else
msg = "method should be either an instance or a subtype of `OptimKit.OptimizationAlgorithm`."
throw(ArgumentError(msg))
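
With this change, passing a method *type* now also installs a `HagerZhangLineSearch` capped at 100 iterations and a quieter default verbosity. A hedged usage sketch (parameter values are illustrative, not recommendations):

```julia
using MPSKit, OptimKit

# method given as a type: the restricted line search is constructed internally
alg = GradientGrassmann(; method=ConjugateGradient, tol=1e-8, maxiter=200)
```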
21 changes: 7 additions & 14 deletions src/algorithms/groundstate/vumps.jl
@@ -30,22 +30,15 @@
# initialization
ϵ::Float64 = calc_galerkin(ψ, envs)
temp_ACs = similar.(ψ.AC)
scheduler = Defaults.scheduler[]
log = IterLog("VUMPS")

LoggingExtras.withlevel(; alg.verbosity) do
@infov 2 loginit!(log, ϵ, sum(expectation_value(ψ, H, envs)))
for iter in 1:(alg.maxiter)
alg_eigsolve = updatetol(alg.alg_eigsolve, iter, ϵ)
@static if Defaults.parallelize_sites
@sync for loc in 1:length(ψ)
Threads.@spawn begin
temp_ACs[loc] = _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
end
end
else
for loc in 1:length(ψ)
temp_ACs[loc] = _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
end
tmap!(temp_ACs, 1:length(ψ); scheduler) do loc
return _vumps_localupdate(loc, ψ, H, envs, alg_eigsolve)
end

alg_gauge = updatetol(alg.alg_gauge, iter, ϵ)
@@ -76,7 +69,10 @@

function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
local AC′, C′
@static if Defaults.parallelize_sites
if Defaults.scheduler[] isa SerialScheduler
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)

Codecov / codecov/patch — check warning on line 74: added lines src/algorithms/groundstate/vumps.jl#L73-L74 were not covered by tests.
else
@sync begin
Threads.@spawn begin
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
@@ -85,9 +81,6 @@
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
end
end
else
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.C[loc], :SR, eigalg)
end
return regauge!(AC′, C′; alg=factalg)
end