Skip to content

Commit 11b71c6

Browse files
committed
Accelerate energy minimization using Optim.Manifold feature (#453)
This implementation avoids the need for "CG restarts" that would reset the stereographic projection axis. As a consequence, energy minimization can sometimes converge much faster (speedups ranging from 30% and 1000%, depending on problem).
1 parent 253fc98 commit 11b71c6

File tree

10 files changed

+154
-218
lines changed

10 files changed

+154
-218
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Sunny"
22
uuid = "2b4a2ac8-8f8b-43e8-abf4-3cb0c45e8736"
3-
authors = ["The Sunny team"]
43
version = "0.8.0"
4+
authors = ["The Sunny team"]
55

66
[deps]
77
Brillouin = "23470ee3-d0df-4052-8b1a-8cbd6363e7f0"
@@ -11,6 +11,7 @@ ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
1111
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1212
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
1313
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
14+
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
MatInt = "f23b31af-f0ea-4208-b2e0-bbfc29c446c9"
1617
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
@@ -45,6 +46,7 @@ FFTW = "1.4.5"
4546
GLMakie = "0.13"
4647
HCubature = "1.7.0"
4748
JLD2 = "0.6.0"
49+
LineSearches = "7.4.1"
4850
LinearAlgebra = "1.10"
4951
Makie = "0.24"
5052
MatInt = "0.1.2"

docs/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
55
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
6-
IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
76
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
87
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
98
Sunny = "2b4a2ac8-8f8b-43e8-abf4-3cb0c45e8736"

docs/src/versions.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
cell will error rather than give a wrong result.
1818
* When loading a CIF or mCIF, the precision parameter `symprec` becomes optional
1919
([#413](@ref)).
20-
* Tune [`minimize_energy!`](@ref) parameters for enhanced robustness. A small
21-
perturbation to the initial spin state breaks artificial symmetries
22-
([#442](@ref)). The return value becomes a struct that stores optimization
23-
statistics ([#430](@ref)).
20+
* Various enhancements to [`minimize_energy!`](@ref). The return value becomes a
21+
struct that stores optimization statistics ([#430](@ref)). A small
22+
perturbation to the initial spin state breaks accidental symmetries
23+
([#442](@ref)). Convergence to the local minimum becomes faster and more
24+
robust ([#453](@ref)).
2425
* Fixes to [`load_nxs`](@ref) ([#420](@ref)).
2526
* Add `interpolate` option to [`plot_intensities`](@ref). Selecting
2627
`interpolate=true` will significantly reduce file sizes of PDF exports

examples/09_Disorder_KPM.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ for (site1, site2, offset) in symmetry_equivalent_bonds(sys_inhom, Bond(1, 1, [1
6262
set_exchange_at!(sys_inhom, 1.0 + noise, site1, site2; offset)
6363
end
6464

65-
minimize_energy!(sys_inhom, maxiters=5_000)
65+
minimize_energy!(sys_inhom, maxiters=2_000)
6666
plot_spins(sys_inhom; color=[S[3] for S in sys_inhom.dipoles], ndims=2)
6767

6868
# Traditional spin wave theory with exact diagonalization becomes impractical
@@ -104,7 +104,7 @@ for site in eachsite(sys_inhom)
104104
sys_inhom.gs[site] = [1 0 0; 0 1 0; 0 0 1+noise]
105105
end
106106
randomize_spins!(sys_inhom)
107-
minimize_energy!(sys_inhom, maxiters=5_000)
107+
minimize_energy!(sys_inhom)
108108

109109
swt = SpinWaveTheoryKPM(sys_inhom; measure=ssf_perp(sys_inhom), tol=0.05)
110110
res = intensities(swt, path; energies, kernel)

src/Optimization.jl

Lines changed: 86 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -22,164 +22,122 @@ function Base.show(io::IO, res::MinimizationResult)
2222
end
2323

2424

25-
# Returns the stereographic projection u(α) = (2v + (1-v²)n)/(1+v²), which
26-
# involves the orthographic projection v = (1-nn̄)α. The input `n` must be
27-
# normalized. When `α=0`, the output is `u=n`, and when `|α|→ ∞` the output is
28-
# `u=-n`. In all cases, `|u|=1`.
29-
function stereographic_projection(α, n)
30-
@assert n'*n 1 || all(isnan, n)
31-
32-
v = α - n*(n'*α) # project out component parallel to `n`
33-
= real(v'*v)
34-
u = (2v + (1-v²)*n) / (1+v²) # stereographic projection
35-
return u
25+
# Manifold where the first nspins are normalized spheres of dimension
26+
# (nelems-1). Any remaining data is interpreted in the Euclidean metric.
27+
struct SpinManifold <: Optim.Manifold
28+
nelems :: Int # Number of scalar components in each spin
29+
nspins :: Int # Number of spins to normalize
3630
end
3731

38-
# Calculate the vector-Jacobian-product x̄ du(α)/dα, where
39-
# u(v) = (2v + (1-v²)n)/(1+v²), v(α,ᾱ)=Pα, and P=1-nn̄.
40-
#
41-
# From the chain rule for Wirtinger derivatives,
42-
# du/dα = (du/dv) (dv/dα) + (du/dv̄) (dv̄/dα) = du/dv P.
43-
#
44-
# In the second step, we used
45-
# dv/dα = P
46-
# dv̄/dα = conj(dv/dᾱ) = 0.
47-
#
48-
# The remaining Jacobian matrix is
49-
# du/dv = (2-2nv̄)/(1+v²) - 2(2v+(1-v²)n)/(1+v²)² v̄
50-
# = c - c[(1+cb)n + cv]v̄,
51-
# where b = (1-v²)/2 and c = 2/(1+v²).
52-
#
53-
# Using the above definitions, return:
54-
# x̄ du/dα = x̄ du/dv P
55-
#
56-
@inline function vjp_stereographic_projection(x̄, α, n)
57-
all(isnan, n) && return zero(n') # No gradient when α is fixed to zero
58-
59-
@assert n'*n 1
60-
61-
v = α - n*(n'*α)
62-
= real(v'*v)
63-
b = (1-v²)/2
64-
c = 2/(1+v²)
65-
# Perform dot products first to avoid constructing outer-product
66-
x̄_dudv = c*' - c * (x̄' * ((1+c*b)*n + c*v)) * v'
67-
# Apply projection P=1-nn̄ on right
68-
return x̄_dudv - (x̄_dudv * n) * n'
69-
end
70-
71-
# Returns v such that u = (2v + (1-v²)n)/(1+v²) and v⋅n = 0
72-
function inverse_stereographic_projection(u, n)
73-
all(isnan, n) && return zero(n) # NaN values denote α = v = zero
32+
optim_spin_view(sm::SpinManifold, x) = view(x, 1:(sm.nelems*sm.nspins))
7433

75-
@assert u'*u 1
76-
77-
uperp = u - n*(n'*u)
78-
uperp² = real(uperp' * uperp)
79-
s = sign(un)
80-
if isone(s) && uperp² < 1e-5
81-
c = 1/2 + uperp²/8 + uperp²*uperp²/16
82-
else
83-
c = (1 - s * sqrt(max(1 - uperp², 0))) / uperp²
34+
function Optim.retract!(sm::SpinManifold, x)
35+
x′ = reshape(optim_spin_view(sm, x), sm.nelems, :)
36+
for j in 1:sm.nspins
37+
xj = view(x′, :, j)
38+
xj ./= norm(xj)
8439
end
85-
return c * uperp
40+
return x
8641
end
8742

88-
function optim_set_spins!(sys::System{0}, αs, ns)
89-
αs = reinterpret(reshape, Vec3, αs)
90-
for site in eachsite(sys)
91-
s = stereographic_projection(αs[site], ns[site])
92-
set_dipole!(sys, s, site)
93-
end
94-
end
95-
function optim_set_spins!(sys::System{N}, αs, ns) where N
96-
αs = reinterpret(reshape, CVec{N}, αs)
97-
for site in eachsite(sys)
98-
Z = stereographic_projection(αs[site], ns[site])
99-
set_coherent!(sys, Z, site)
43+
function Optim.project_tangent!(sm::SpinManifold, g, x)
44+
x = reshape(optim_spin_view(sm, x), sm.nelems, :)
45+
g = reshape(optim_spin_view(sm, g), sm.nelems, :)
46+
for j in 1:sm.nspins
47+
xj = view(x, :, j)
48+
gj = view(g, :, j)
49+
gj .= gj .- xj .* ((xj' * gj) / norm2(xj))
10050
end
51+
return nothing
10152
end
10253

103-
function optim_set_gradient!(G, sys::System{0}, αs, ns)
104-
(αs, G) = reinterpret.(reshape, Vec3, (αs, G))
105-
set_energy_grad_dipoles!(G, sys.dipoles, sys) # G = dE/dS
106-
@. G *= norm(sys.dipoles) # G = dE/dS * dS/du = dE/du
107-
@. G = adjoint(vjp_stereographic_projection(G, αs, ns)) # G = dE/du du/dα = dE/dα
108-
end
109-
function optim_set_gradient!(G, sys::System{N}, αs, ns) where N
110-
(αs, G) = reinterpret.(reshape, CVec{N}, (αs, G))
111-
set_energy_grad_coherents!(G, sys.coherents, sys) # G = dE/dZ
112-
@. G *= norm(sys.coherents) # G = dE/dZ * dZ/du = dE/du
113-
@. G = adjoint(vjp_stereographic_projection(G, αs, ns)) # G = dE/du du/dα = dE/dα
54+
function optimize_with_restarts(; calc_f, calc_g!, x, method, maxiters, options_args)
55+
iters = 0
56+
while true
57+
options = Optim.Options(; iterations=maxiters-iters, options_args...)
58+
res = Optim.optimize(calc_f, calc_g!, x, method, options)
59+
x = Optim.minimizer(res)
60+
iters += Optim.iterations(res)
61+
if Optim.converged(res) || iters >= maxiters
62+
return (res, iters)
63+
end
64+
end
11465
end
11566

11667

11768
"""
118-
minimize_energy!(sys::System; maxiters=1000, method=Optim.ConjugateGradient(),
119-
kwargs...)
69+
minimize_energy!(sys::System; maxiters=1000, kwargs...)
12070
12171
Optimizes the spin configuration in `sys` to minimize energy. A total of
122-
`maxiters` iterations will be attempted. The `method` parameter will be used in
123-
the `optimize` function of the [Optim.jl
124-
package](https://github.com/JuliaNLSolvers/Optim.jl). Any remaining `kwargs`
125-
will be included in the `Options` constructor of Optim.jl.
72+
`maxiters` iterations will be attempted. Any remaining `kwargs` will be included
73+
in the `Options` constructor of the [Optim.jl
74+
package](https://github.com/JuliaNLSolvers/Optim.jl)
12675
12776
Convergence status is stored in the field `ret.converged` of the return value
12877
`ret`. Additional optimization statistics are stored in the field `ret.data`.
12978
"""
130-
function minimize_energy!(sys::System{N}; maxiters=1000, method=Optim.ConjugateGradient(),
131-
subiters=10, δ=1e-8, kwargs...) where N
132-
# Perturbation of sufficient magnitude to "almost surely" push away from an
133-
# unstable stationary point (e.g. local maximum or saddle).
79+
function minimize_energy!(sys::System{N}; maxiters=1000, δ=1e-8, kwargs...) where N
80+
# Small perturbation to destabilize an accidental stationary point (local
81+
# maximum or saddle).
13482
perturb_spins!(sys, δ)
13583

136-
# Allocate buffers for optimization:
137-
# - Each `ns[site]` defines a direction for stereographic projection.
138-
# - Each `αs[:,site]` will be optimized in the space orthogonal to `ns[site]`.
139-
if iszero(N)
140-
ns = normalize.(sys.dipoles)
141-
αs = zeros(Float64, 3, size(sys.dipoles)...)
84+
# Optimization variables are normalized spins or coherent states. In case of
85+
# a vacancy, use an arbitrary representative on the sphere: [1, 1, …] / √N.
86+
normalize_or_fallback(x) = normalize(iszero(x) ? one.(x) : x)
87+
x = if iszero(N)
88+
collect(vec(reinterpret(Float64, normalize_or_fallback.(sys.dipoles))))
14289
else
143-
ns = normalize.(sys.coherents)
144-
αs = zeros(ComplexF64, N, size(sys.coherents)...)
90+
collect(vec(reinterpret(ComplexF64, normalize_or_fallback.(sys.coherents))))
91+
end
92+
93+
# Load spins into system
94+
function load_spins!(x)
95+
if iszero(N)
96+
x = reinterpret(Vec3, x)
97+
for (i, site) in enumerate(eachsite(sys))
98+
set_dipole!(sys, x[i], site)
99+
end
100+
else
101+
x = reinterpret(CVec{N}, x)
102+
for (i, site) in enumerate(eachsite(sys))
103+
set_coherent!(sys, x[i], site)
104+
end
105+
end
145106
end
146107

147-
# Functions to calculate energy and gradient for the state `αs`
148-
function f(αs)
149-
optim_set_spins!(sys, αs, ns)
108+
# Energy and gradient callback functions
109+
function calc_f(x)
110+
load_spins!(x)
150111
return energy(sys)
151112
end
152-
function g!(G, αs)
153-
optim_set_spins!(sys, αs, ns)
154-
optim_set_gradient!(G, sys, αs, ns)
113+
function calc_g!(g, x)
114+
load_spins!(x)
115+
if iszero(N)
116+
g = reshape(reinterpret(Vec3, g), size(sys.dipoles))
117+
set_energy_grad_dipoles!(g, sys.dipoles, sys)
118+
@. g *= norm(sys.dipoles) # Sensitivity to change in unit spin
119+
else
120+
g = reshape(reinterpret(CVec{N}, g), size(sys.coherents))
121+
set_energy_grad_coherents!(g, sys.coherents, sys)
122+
@. g *= norm(sys.coherents) # Sensitivity to change in unit ket
123+
end
124+
return nothing
155125
end
156126

157-
# Repeatedly optimize using a small number (`subiters`) of steps, within
158-
# which the stereographic projection axes are fixed. Because we require
159-
# high-precision in the spin variables (x), disable check on the energy (f),
160-
# which is much lower precision. Disable check on x_reltol because x=[zeros]
161-
# is a valid configuration (all spins aligned with the stereographic
162-
# projection axes). Optim interprets x_abstol and g_abstol in the p=Inf norm
163-
# (largest vector component). Note that x is dimensionless and g=dE/dx has
164-
# energy units.
127+
# Disable check on the energy f, because we require high precision in the
128+
# dimensionless spin variables x. The checks x_abstol and g_abstol are in
129+
# the p=Inf norm (largest vector component).
165130
x_abstol = 1e-12
166131
g_abstol = 1e-12 * characteristic_energy_scale(sys)
167-
options = Optim.Options(; iterations=subiters, g_abstol, x_abstol, x_reltol=NaN, f_reltol=NaN, f_abstol=NaN, kwargs...)
168-
local res
169-
for iter in 1 : div(maxiters, subiters, RoundUp)
170-
res = Optim.optimize(f, g!, αs, method, options)
171-
172-
if Optim.converged(res)
173-
cnt = (iter-1)*subiters + res.iterations
174-
return MinimizationResult(true, cnt, res)
175-
end
176-
177-
# Reset stereographic projection based on current state
178-
ns .= normalize.(iszero(N) ? sys.dipoles : sys.coherents)
179-
αs .*= 0
132+
manifold = SpinManifold(iszero(N) ? 3 : N, length(eachsite(sys)))
133+
method = Optim.ConjugateGradient(; alphaguess=LineSearches.InitialHagerZhang(; αmax=10.0), manifold)
134+
options_args = (; g_abstol, x_abstol, x_reltol=NaN, f_reltol=NaN, f_abstol=NaN, kwargs...)
135+
(res, iters) = optimize_with_restarts(; calc_f, calc_g!, x, method, maxiters, options_args)
136+
137+
load_spins!(Optim.minimizer(res))
138+
mr = MinimizationResult(Optim.converged(res), iters, res)
139+
if !mr.converged
140+
@warn repr("text/plain", mr)
180141
end
181-
182-
mr = MinimizationResult(false, maxiters, res)
183-
@warn repr("text/plain", mr)
184142
return mr
185143
end

0 commit comments

Comments
 (0)