Skip to content

Commit d2b7359

Browse files
authored
Option to specify timeout for SCFs (#948)
1 parent 5049ffb commit d2b7359

File tree

5 files changed

+21
-15
lines changed

5 files changed

+21
-15
lines changed

ext/DFTKJLD2Ext.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ DFTK.make_subdict!(jld::Union{JLD2.Group,JLD2.JLDFile}, name::AbstractString) =
88
function save_jld2(to_dict_function!, file::AbstractString, scfres::NamedTuple;
99
save_ψ=true, save_ρ=true, extra_data=Dict{String,Any}(), compress=false)
1010
if mpi_master()
11-
JLD2.jldopen(file, "w"; compress) do jld
11+
JLD2.jldopen(file * ".new", "w"; compress) do jld
1212
to_dict_function!(jld, scfres; save_ψ, save_ρ)
1313
for (k, v) in pairs(extra_data)
1414
jld[k] = v
@@ -19,6 +19,7 @@ function save_jld2(to_dict_function!, file::AbstractString, scfres::NamedTuple;
1919
delete!(jld, "kgrid")
2020
jld["kgrid"] = scfres.basis.kgrid # Save original kgrid datastructure
2121
end
22+
mv(file * ".new", file; force=true)
2223
else
2324
dummy = Dict{String,Any}()
2425
to_dict_function!(dummy, scfres; save_ψ)

ext/DFTKJSON3Ext.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ function save_json(todict_function, filename::AbstractString, scfres::NamedTuple
1313
data[k] = v
1414
end
1515
if mpi_master()
16-
open(filename, "w") do io
17-
JSON3.pretty(io, data)
16+
open(filename * ".new", "w") do io
17+
JSON3.write(io, data)
1818
end
19+
mv(filename * ".new", filename; force=true)
1920
end
2021
MPI.Barrier(MPI.COMM_WORLD)
2122
nothing

src/scf/scf_solvers.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
# maxiter), where f(x) is the fixed-point map. It must return an
55
# object supporting res.sol and res.converged
66

7-
# TODO max_iter could go to the solver generator function arguments
8-
97
"""
108
Create a damped SCF solver updating the density as
119
`x = β * x_new + (1 - β) * x`
1210
"""
1311
function scf_damping_solver=0.2)
14-
function fp_solver(f, x0, max_iter; tol=1e-6)
12+
function fp_solver(f, x0, maxiter; tol=1e-6)
1513
β = convert(eltype(x0), β)
1614
converged = false
1715
x = copy(x0)
18-
for i = 1:max_iter
16+
for i = 1:maxiter
1917
x_new = f(x)
2018

2119
if norm(x_new - x) < tol
@@ -36,13 +34,13 @@ Create a simple anderson-accelerated SCF solver. `m` specifies the number
3634
of steps to keep the history of.
3735
"""
3836
function scf_anderson_solver(m=10; kwargs...)
39-
function anderson(f, x0, max_iter; tol=1e-6)
37+
function anderson(f, x0, maxiter; tol=1e-6)
4038
T = eltype(x0)
4139
x = x0
4240

4341
converged = false
4442
acceleration = AndersonAcceleration(; m, kwargs...)
45-
for n = 1:max_iter
43+
for n = 1:maxiter
4644
residual = f(x) - x
4745
converged = norm(residual) < tol
4846
converged && break
@@ -57,7 +55,7 @@ CROP-accelerated root-finding iteration for `f`, starting from `x0` and keeping
5755
a history of `m` steps. Optionally `warming` specifies the number of non-accelerated
5856
steps to perform for warming up the history.
5957
"""
60-
function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
58+
function CROP(f, x0, m::Int, maxiter::Int, tol::Real, warming=0)
6159
# CROP iterates maintain xn and fn (/!\ fn != f(xn)).
6260
# xtn+1 = xn + fn
6361
# ftn+1 = f(xtn+1)
@@ -70,7 +68,7 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
7068

7169
# Cheat support for multidimensional arrays
7270
if length(size(x0)) != 1
73-
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, max_iter, tol, warming)
71+
x, conv= CROP(x -> vec(f(reshape(x, size(x0)...))), vec(x0), m, maxiter, tol, warming)
7472
return (; fixpoint=reshape(x, size(x0)...), converged=conv)
7573
end
7674
N = size(x0,1)
@@ -79,10 +77,10 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
7977
fs = zeros(T, N, m+1) # newest to oldest
8078
xs[:,1] = x0
8179
fs[:,1] = f(x0) # Residual
82-
errs = zeros(max_iter)
80+
errs = zeros(maxiter)
8381
err = Inf
8482

85-
for n = 1:max_iter
83+
for n = 1:maxiter
8684
xtnp1 = xs[:, 1] + fs[:, 1] # Richardson update
8785
ftnp1 = f(xtnp1) # Residual
8886
err = norm(ftnp1)
@@ -112,4 +110,4 @@ function CROP(f, x0, m::Int, max_iter::Int, tol::Real, warming=0)
112110
end
113111
(; fixpoint=xs[:, 1], converged=err < tol)
114112
end
115-
scf_CROP_solver(m=10) = (f, x0, max_iter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, max_iter, tol)
113+
scf_CROP_solver(m=10) = (f, x0, maxiter; tol=1e-6) -> CROP(x -> f(x) - x, x0, m, maxiter, tol)

src/scf/self_consistent_field.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
include("scf_callbacks.jl")
2+
using Dates
23

34
"""
45
Transparently handle checkpointing by either returning kwargs for `self_consistent_field`,
@@ -109,6 +110,8 @@ Overview of parameters:
109110
- `is_converged`: Convergence control callback. Typical objects passed here are
110111
`ScfConvergenceDensity(tol)` (the default), `ScfConvergenceEnergy(tol)` or `ScfConvergenceForce(tol)`.
111112
- `maxiter`: Maximal number of SCF iterations
113+
- `maxtime`: Maximal time to run the SCF for. If this is reached without
114+
convergence, the SCF stops.
112115
- `mixing`: Mixing method, which determines the preconditioner ``P^{-1}`` in the above equation.
113116
Typical mixings are [`LdosMixing`](@ref), [`KerkerMixing`](@ref), [`SimpleMixing`](@ref)
114117
or [`DielectricMixing`](@ref). Default is `LdosMixing()`
@@ -129,6 +132,7 @@ Overview of parameters:
129132
tol=1e-6,
130133
is_converged=ScfConvergenceDensity(tol),
131134
maxiter=100,
135+
maxtime=Year(1),
132136
mixing=LdosMixing(),
133137
damping=0.8,
134138
solver=scf_anderson_solver(),
@@ -152,6 +156,7 @@ Overview of parameters:
152156
energies = nothing
153157
ham = nothing
154158
start_ns = time_ns()
159+
end_time = Dates.now() + maxtime
155160
info = (; n_iter=0, ρin=ρ) # Populate info with initial values
156161
history_Etot = T[]
157162
history_Δρ = T[]
@@ -161,6 +166,7 @@ Overview of parameters:
161166
# TODO support other mixing types
162167
function fixpoint_map(ρin)
163168
converged && return ρin # No more iterations if convergence flagged
169+
MPI.bcast(Dates.now() end_time, MPI.COMM_WORLD) && return ρin
164170
n_iter += 1
165171

166172
# Note that ρin is not the density of ψ, and the eigenvalues

test/scf_compare.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
# Run other SCFs with SAD guess
3535
ρ0 = guess_density(basis)
36-
for solver in (scf_anderson_solver(), scf_damping_solver(1.0), scf_CROP_solver())
36+
for solver in (scf_anderson_solver(), scf_damping_solver(), scf_CROP_solver())
3737
@testset "Testing $solver" begin
3838
ρ_alg = self_consistent_field(basis; ρ=ρ0, solver, tol).ρ
3939
@test maximum(abs, ρ_alg - ρ_def) < 50tol

0 commit comments

Comments
 (0)