Skip to content

Commit c061abf

Browse files
Technici4n authored and gkemlin committed
Make SCF runs optionally reproducible by providing a seed (JuliaMolSim#1161)
1 parent 85cda3e commit c061abf

File tree

10 files changed

+70
-9
lines changed

10 files changed

+70
-9
lines changed

src/DFTK.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ include("common/hankel.jl")
4545
include("common/hydrogenic.jl")
4646
include("common/derivatives.jl")
4747
include("common/linalg.jl")
48+
include("common/random.jl")
4849

4950
export PspHgh
5051
export PspUpf

src/common/random.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
    seed_task_local_rng!(seed::Union{Nothing,Integer}, comm)

Seeds the task local RNG across all MPI ranks of `comm` and returns the global seed.

A seed can be provided for reproducibility with previous runs
(for the same Julia version and Manifest.toml).
If no seed is provided, a random seed is generated on the master process.
The global seed is broadcast over `comm` and returned, so that it can be used
to reproduce the run.

The master process derives one local seed per process from the global seed and
scatters them, such that each process seeds its task local RNG differently.
The local seeds are always drawn as `UInt64`, independently of the integer type
of `seed`, since negative seeds are only supported from Julia 1.11 onwards.

If any subtask is spawned, it will be seeded based on the task local RNG of its parent,
as explained in the documentation of `Random.TaskLocalRNG`.
Seeding the task local RNG at the beginning of a computation is thus sufficient.
"""
function seed_task_local_rng!(seed::Union{Nothing,Integer}, comm)
    if mpi_master(comm) && isnothing(seed)
        # Using negative seeds requires Julia 1.11 and DFTK still supports 1.10
        seed = rand(UInt64)
    end
    seed = MPI.bcast(seed, comm)
    if mpi_master(comm)
        # Seed with the global seed first, such that the local seeds drawn below
        # are a deterministic function of it.
        Random.seed!(seed)
        # Generate a different seed for each process. Draw them as UInt64 rather than
        # typeof(seed): for a signed user-provided seed (e.g. `seed=42`) the latter
        # would yield negative local seeds, which Random.seed! rejects before Julia 1.11.
        local_seeds = rand(UInt64, mpi_nprocs(comm))
        local_seed = MPI.scatter(local_seeds, comm)
    else
        local_seed = MPI.scatter(nothing, comm)
    end
    Random.seed!(local_seed)
    seed
end

src/postprocess/band_structure.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ All kwargs not specified below are passed to [`diagonalize_all_kblocks`](@ref):
1515
kgrid::Union{AbstractKgrid,AbstractKgridGenerator};
1616
n_bands=default_n_bands_bandstructure(basis.model),
1717
n_extra=3, ρ=nothing, τ=nothing, εF=nothing,
18-
eigensolver=lobpcg_hyper, tol=1e-3, kwargs...)
18+
eigensolver=lobpcg_hyper, tol=1e-3, seed=nothing,
19+
kwargs...)
1920
# kcoords are the kpoint coordinates in fractional coordinates
2021
if isnothing(ρ)
2122
if any(t isa TermNonlinear for t in basis.terms)
@@ -30,6 +31,7 @@ All kwargs not specified below are passed to [`diagonalize_all_kblocks`](@ref):
3031
"quantity to compute_bands as the τ keyword argument or use the " *
3132
"compute_bands(scfres) function.")
3233
end
34+
seed = seed_task_local_rng!(seed, MPI.COMM_WORLD)
3335

3436
# Create new basis with new kpoints
3537
bs_basis = PlaneWaveBasis(basis, kgrid)
@@ -54,7 +56,7 @@ All kwargs not specified below are passed to [`diagonalize_all_kblocks`](@ref):
5456
# types subtype. In a first version the ScfResult could just contain
5557
# the currently used named tuple and forward all operations to it.
5658
(; basis=bs_basis, ψ=eigres.X, eigenvalues=eigres.λ, ρ, εF, occupation,
57-
diagonalization=[eigres])
59+
diagonalization=[eigres], seed)
5860
end
5961

6062
"""

src/scf/direct_minimization.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ function direct_minimization(basis::PlaneWaveBasis{T};
7575
optim_method=Optim.LBFGS,
7676
alphaguess=LineSearches.InitialStatic(),
7777
linesearch=LineSearches.BackTracking(),
78+
seed=nothing,
7879
kwargs...) where {T}
7980
if mpi_nprocs() > 1
8081
# need synchronization in Optim
8182
error("Direct minimization with MPI is not supported yet")
8283
end
84+
seed = seed_task_local_rng!(seed, MPI.COMM_WORLD)
8385
model = basis.model
8486
@assert iszero(model.temperature) # temperature is not yet supported
8587
@assert isnothing(model.εF) # neither are computations with fixed Fermi level
@@ -189,7 +191,7 @@ function direct_minimization(basis::PlaneWaveBasis{T};
189191
# We rely on the fact that the last point where fg! was called is the minimizer to
190192
# avoid recomputing at ψ
191193
info = (; ham, basis, energies, converged, ρ, eigenvalues, occupation, εF,
192-
n_bands_converge=n_bands, n_iter=Optim.iterations(res),
194+
n_bands_converge=n_bands, n_iter=Optim.iterations(res), seed,
193195
runtime_ns=time_ns() - start_ns, history_Δρ, history_Etot,
194196
ψ, stage=:finalize, algorithm="DM", optim_res=res)
195197
callback(info)

src/scf/newton.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ from the solution.
8181
function newton(basis::PlaneWaveBasis{T}, ψ0;
8282
tol=1e-6, tol_cg=tol / 100, maxiter=20,
8383
callback=ScfDefaultCallback(),
84-
is_converged=ScfConvergenceDensity(tol)) where {T}
84+
is_converged=ScfConvergenceDensity(tol),
85+
seed=nothing) where {T}
8586

87+
seed = seed_task_local_rng!(seed, MPI.COMM_WORLD)
8688
# setting parameters
8789
model = basis.model
8890
@assert iszero(model.temperature) # temperature is not yet supported
@@ -149,7 +151,7 @@ function newton(basis::PlaneWaveBasis{T}, ψ0;
149151
# return results and call callback one last time with final state for clean
150152
# up
151153
info = (; ham=H, basis, energies, converged, ρ, eigenvalues, occupation, εF, n_iter, ψ,
152-
stage=:finalize, algorithm="Newton", runtime_ns=time_ns() - start_ns)
154+
stage=:finalize, algorithm="Newton", seed, runtime_ns=time_ns() - start_ns)
153155
callback(info)
154156
info
155157
end

src/scf/potential_mixing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ Simple SCF algorithm using potential mixing. Parameters are largely the same as
175175
acceleration=AndersonAcceleration(;m=10),
176176
accept_step=ScfAcceptStepAll(),
177177
max_backtracks=3, # Maximal number of backtracking line searches
178+
seed=nothing,
178179
)
179180
# TODO Test other mixings and lift this
180181
@assert ( mixing isa SimpleMixing
@@ -185,6 +186,7 @@ Simple SCF algorithm using potential mixing. Parameters are largely the same as
185186
if !isnothing(ψ)
186187
@assert length(ψ) == length(basis.kpoints)
187188
end
189+
seed = seed_task_local_rng!(seed, MPI.COMM_WORLD)
188190

189191
# Initial guess for V (if none given)
190192
ham = energy_hamiltonian(basis, nothing, nothing; ρ).ham
@@ -284,7 +286,7 @@ Simple SCF algorithm using potential mixing. Parameters are largely the same as
284286
info = (; ham, basis, info.energies, converged, ρ=info.ρout, info.eigenvalues,
285287
info.occupation, info.εF, n_iter, info.ψ, info.n_bands_converge,
286288
info.diagonalization, stage=:finalize, algorithm="SCF",
287-
history_Δρ, history_Etot, info.occupation_threshold,
289+
history_Δρ, history_Etot, info.occupation_threshold, seed,
288290
runtime_ns=time_ns() - start_ns)
289291
callback(info)
290292
info

src/scf/self_consistent_field.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,15 @@ Overview of parameters:
146146
fermialg::AbstractFermiAlgorithm=default_fermialg(basis.model),
147147
callback=ScfDefaultCallback(; show_damping=false),
148148
compute_consistent_energies=true,
149+
seed=nothing,
149150
response=ResponseOptions(), # Dummy here, only for AD
150151
) where {T}
151152
if !isnothing(ψ)
152153
@assert length(ψ) == length(basis.kpoints)
153154
end
154155
start_ns = time_ns()
155156
timeout_date = Dates.now() + maxtime
157+
seed = seed_task_local_rng!(seed, MPI.COMM_WORLD)
156158

157159
# We do density mixing in the real representation
158160
# TODO support other mixing types
@@ -182,7 +184,7 @@ Overview of parameters:
182184
# Update info with results gathered so far
183185
info_next = (; ham, basis, converged, stage=:iterate, algorithm="SCF",
184186
ρin, τ, α=damping, n_iter, nbandsalg.occupation_threshold,
185-
runtime_ns=time_ns() - start_ns, nextstate...,
187+
seed, runtime_ns=time_ns() - start_ns, nextstate...,
186188
diagonalization=[nextstate.diagonalization])
187189

188190
# Compute the energy of the new state
@@ -226,7 +228,7 @@ Overview of parameters:
226228
scfres = (; ham, basis, energies, converged, nbandsalg.occupation_threshold,
227229
ρ=ρout, τ, α=damping, eigenvalues, occupation, εF, info.n_bands_converge,
228230
info.n_iter, info.n_matvec, ψ, info.diagonalization, stage=:finalize,
229-
info.history_Δρ, info.history_Etot, info.timedout, mixing,
231+
info.history_Δρ, info.history_Etot, info.timedout, mixing, seed,
230232
runtime_ns=time_ns() - start_ns, algorithm="SCF")
231233
callback(scfres)
232234
scfres

src/workarounds/forwarddiff_rules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
275275
scfres.converged, scfres.occupation_threshold, scfres.α, scfres.n_iter,
276276
scfres.n_bands_converge, scfres.n_matvec, scfres.diagonalization, scfres.stage,
277277
scfres.history_Δρ, scfres.history_Etot, scfres.timedout, scfres.mixing,
278-
scfres.algorithm, scfres.runtime_ns)
278+
scfres.seed, scfres.algorithm, scfres.runtime_ns)
279279
end
280280

281281
function hankel(r::AbstractVector, r2_f::AbstractVector, l::Integer, p::TT) where {TT <: ForwardDiff.Dual}

test/reproducibility.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Runs the same SCF twice, the second time with the seed reported by the first run,
# and checks that both runs agree exactly.
@testitem "Reproducibility of seeded SCF runs" setup=[TestCases] begin
    using DFTK
    silicon = TestCases.silicon

    model = model_DFT(silicon.lattice, silicon.atoms, silicon.positions; functionals=LDA())
    Ecut = 15
    kgrid = [2, 2, 2]

    basis = PlaneWaveBasis(model; Ecut, kgrid)
    scfres1 = self_consistent_field(basis; tol=1e-7)

    # Use seed from scfres1 for reproducibility
    scfres2 = self_consistent_field(basis; tol=1e-7, scfres1.seed)

    # The seed passed in should be the one reported back by the second run.
    @test scfres1.seed == scfres2.seed

    # Should be exactly equal if the computation is reproducible, no need for epsilons.
    # Use @test rather than @assert, such that a mismatch is recorded as a test failure
    # instead of aborting the test item with an AssertionError.
    @test scfres1.history_Etot == scfres2.history_Etot
    @test scfres1.history_Δρ == scfres2.history_Δρ
    @test scfres1.ψ == scfres2.ψ
    @test scfres1.ρ == scfres2.ρ
end

test/serialisation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ function test_scfres_agreement(tested, ref; test_ψ=true)
4040
if test_ψ
4141
@test tested.ψ == ref.ψ
4242
end
43+
44+
@test tested.seed == ref.seed
4345
end
4446
end
4547

0 commit comments

Comments
 (0)