Skip to content

Commit 4147fe6

Browse files
antoine-levittniklasschmitz
authored andcommitted
Fix DFPT wrt temperature (#1156)
1 parent e14d746 commit 4147fe6

File tree

11 files changed

+113
-65
lines changed

11 files changed

+113
-65
lines changed

examples/polarizability.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ println("Polarizability : $polarizability")
105105

106106
## Multiply δVext times the Bloch waves, then solve the Dyson equation:
107107
δVψ = DFTK.multiply_ψ_by_blochwave(scfres.basis, scfres.ψ, δVext)
108-
res = DFTK.solve_ΩplusK_split(scfres, -δVψ; verbose=true)
108+
res = DFTK.solve_ΩplusK_split(scfres, δVψ; verbose=true)
109109

110110
# From the result of `solve_ΩplusK_split` we can easily compute the polarisabilities:
111111

src/Model.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,24 +92,27 @@ If you want to pass custom symmetry operations (e.g. a reduced or extended set)
9292
external potential breaks some of the passed symmetries. Use `false` to turn off
9393
symmetries completely.
9494
"""
95-
function Model(lattice::AbstractMatrix{T},
95+
function Model(lattice::AbstractMatrix{Tstatic},
9696
atoms::Vector{<:Element}=Element[],
97-
positions::Vector{<:AbstractVector}=Vec3{T}[];
97+
positions::Vector{<:AbstractVector}=Vec3{Tstatic}[];
9898
model_name="custom",
9999
εF=nothing,
100100
n_electrons::Union{Int,Nothing}=isnothing(εF) ?
101101
n_electrons_from_atoms(atoms) : nothing,
102102
# Force electrostatics with non-neutral cells; results not guaranteed.
103103
# Set to `true` by default for charged systems.
104104
disable_electrostatics_check=all(iszero, charge_ionic.(atoms)),
105-
magnetic_moments=T[],
105+
magnetic_moments=Tstatic[],
106106
terms=[Kinetic()],
107-
temperature=zero(T),
107+
temperature=zero(Tstatic),
108108
smearing=temperature > 0 ? Smearing.FermiDirac() : Smearing.None(),
109109
spin_polarization=determine_spin_polarization(magnetic_moments),
110110
symmetries=default_symmetries(lattice, atoms, positions, magnetic_moments,
111111
spin_polarization, terms),
112-
) where {T <: Real}
112+
) where {Tstatic <: Real}
113+
# # a bit convoluted because kwargs can't determine type parameters
114+
T = promote_type(Tstatic, typeof(temperature))
115+
113116
# Validate εF and n_electrons
114117
if !isnothing(εF) # fixed Fermi level
115118
if !isnothing(n_electrons)
@@ -250,9 +253,9 @@ function Model{T}(model::Model;
250253
Model(T.(lattice), atoms, positions;
251254
model.model_name,
252255
model.n_electrons,
253-
magnetic_moments,
256+
magnetic_moments=T.(magnetic_moments),
254257
terms=model.term_types,
255-
model.temperature,
258+
temperature=T(model.temperature),
256259
model.smearing,
257260
model.εF,
258261
model.spin_polarization,

src/Smearing.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ function xlogx(x)
8383
end
8484
function entropy(S::FermiDirac, x)
8585
f = occupation(S, x)
86-
- (xlogx(f) + xlogx(1 - f))
86+
# protect against the occupation being exactly zero or one, which causes trouble with the derivative
87+
# this check is a bit stupid, but if we just check for f == 0, the branch won't get picked up by forwarddiff
88+
if abs(f) < eps(typeof(x))|| abs(1-f) < eps(typeof(x))
89+
zero(x)
90+
else
91+
- (xlogx(f) + xlogx(1 - f))
92+
end
8793
end
8894
function occupation_divided_difference(S::FermiDirac, x, y, εF, temp)
8995
temp == 0 && return occupation_divided_difference(None(), x, y, εF, temp)

src/postprocess/phonon.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ in reduced coordinates.
8282
isnothing(δHψs_αs) && continue
8383
# Response solver to get δψ
8484
(; δψ, δρ, δoccupation) = solve_ΩplusK_split(ham, ρ, ψ, occupation, εF, eigenvalues,
85-
-δHψs_αs; q, kwargs...)
85+
δHψs_αs; q, kwargs...)
8686
δoccupations[α, s] = δoccupation
8787
δρs[α, s] = δρ
8888
δψs[α, s] = δψ

src/postprocess/refine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function refine_scfres(scfres, basis_ref::PlaneWaveBasis{T};
150150
ΩpKe2 = apply_Ω(e2, ψr, hamr, Λ) .+ apply_K(basis_ref, e2, ψr, ρr, occ)
151151
ΩpKe2 = transfer_blochwave(ΩpKe2, basis_ref, basis)
152152

153-
rhs = resLF - ΩpKe2
153+
rhs = ΩpKe2 - resLF
154154

155155
# Invert Ω+K on the small space
156156
ΩpK_res = solve_ΩplusK(basis, ψ, rhs, occ; tol, kwargs...)

src/response/chi0.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -275,36 +275,37 @@ The derivatives of the occupations are in-place stored in δocc.
275275
The tuple (; δocc, δεF) is returned. It is assumed the passed `δocc`
276276
are initialised to zero.
277277
"""
278-
function compute_δocc!(δocc, basis::PlaneWaveBasis{T}, ψ, εF, ε, δHψ) where {T}
278+
function compute_δocc!(δocc, basis::PlaneWaveBasis{T}, ψ, εF, ε, δHψ, δtemperature) where {T}
279279
model = basis.model
280280
temperature = model.temperature
281281
smearing = model.smearing
282282
filled_occ = filled_occupation(model)
283283

284-
# δocc = fn' * (δεn - δεF)
284+
# compute the derivative of
285+
# occ[k][n] = filled_occ*occupation((εnk-εF)/T)
285286
δεF = zero(T)
286287
if !is_effective_insulator(basis, ε, εF; smearing, temperature)
287288
# First compute δocc without self-consistent Fermi δεF.
288289
D = zero(T)
289290
for ik = 1:length(basis.kpoints), (n, εnk) in enumerate(ε[ik])
290-
enred = (εnk - εF) / temperature
291291
δεnk = real(dot(ψ[ik][:, n], δHψ[ik][:, n]))
292-
fpnk = filled_occ * Smearing.occupation_derivative(smearing, enred) / temperature
293-
δocc[ik][n] = δεnk * fpnk
294-
D += fpnk * basis.kweights[ik]
292+
εnkred = (εnk - εF) / temperature
293+
δεnkred = δεnk/temperature - εnkred*δtemperature/temperature
294+
fpnk = filled_occ * Smearing.occupation_derivative(smearing, εnkred)
295+
δocc[ik][n] = fpnk * δεnkred
296+
D -= fpnk * basis.kweights[ik] / temperature # while we're at it, accumulate the total DOS D
295297
end
296-
D = mpi_sum(D, basis.comm_kpts) # equal to minus the total DOS
298+
D = mpi_sum(D, basis.comm_kpts)
297299

298300
if isnothing(model.εF) # εF === nothing means that Fermi level is fixed by model
299-
# Compute δεF…
301+
# Compute δεF from δ ∑ occ = 0
300302
δocc_tot = mpi_sum(sum(basis.kweights .* sum.(δocc)), basis.comm_kpts)
301-
δεF = δocc_tot / D
303+
δεF = -δocc_tot / D
302304

303-
# … and recompute δocc, taking into account δεF.
305+
# … and add the corresponding contribution to δocc
304306
for ik = 1:length(basis.kpoints), (n, εnk) in enumerate(ε[ik])
305-
enred = (εnk - εF) / temperature
306-
fpnk = filled_occ * Smearing.occupation_derivative(smearing, enred) / temperature
307-
δocc[ik][n] -= fpnk * δεF
307+
fpnk = filled_occ * Smearing.occupation_derivative(smearing, (εnk - εF) / temperature)
308+
δocc[ik][n] -= fpnk * δεF / temperature
308309
end
309310
end
310311
end
@@ -396,6 +397,7 @@ Compute the orbital and occupation changes as a result of applying the ``χ_0``
396397
to the Hamiltonian change `δH` represented by the matrix-vector products `δHψ`.
397398
"""
398399
@views @timing function apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
400+
δtemperature=zero(eltype(ham.basis)),
399401
occupation_threshold, q=zero(Vec3{eltype(ham.basis)}),
400402
bandtolalg, tol=1e-9, kwargs_sternheimer...)
401403
basis = ham.basis
@@ -436,10 +438,11 @@ to the Hamiltonian change `δH` represented by the matrix-vector products `δHψ
436438
δoccupation = zero.(occupation)
437439
if iszero(q)
438440
δocc_occ = [δoccupation[ik][maskk] for (ik, maskk) in enumerate(mask_occ)]
439-
(; δεF) = compute_δocc!(δocc_occ, basis, ψ_occ, εF, ε_occ, δHψ_minus_q_occ)
441+
(; δεF) = compute_δocc!(δocc_occ, basis, ψ_occ, εF, ε_occ, δHψ_minus_q_occ, δtemperature)
440442
else
441443
# When δH is not periodic, δH ψnk is a Bloch wave at k+q and ψnk at k,
442444
# so that δεnk = <ψnk|δH|ψnk> = 0 and there is no occupation shift
445+
@assert δtemperature == 0 # TODO think about this
443446
δεF = zero(εF)
444447
end
445448

@@ -466,6 +469,7 @@ Parameters:
466469
- `maxiter`: Maximal number of CG iterations per k and band for Sternheimer
467470
"""
468471
function apply_χ0(ham, ψ, occupation, εF::T, eigenvalues, δV::AbstractArray{TδV};
472+
δtemperature=zero(eltype(ham.basis)),
469473
occupation_threshold=default_occupation_threshold(TδV),
470474
q=zero(Vec3{eltype(ham.basis)}),
471475
bandtolalg=BandtolBalanced(ham.basis, ψ, occupation; occupation_threshold),
@@ -493,7 +497,7 @@ function apply_χ0(ham, ψ, occupation, εF::T, eigenvalues, δV::AbstractArray{
493497
# δHψ_k = δV_{q} · ψ_{k-q}.
494498
δHψ = multiply_ψ_by_blochwave(basis, ψ, δV, q)
495499
res = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
496-
occupation_threshold, q, bandtolalg,
500+
δtemperature, occupation_threshold, q, bandtolalg,
497501
kwargs_sternheimer...)
498502

499503
δρ = compute_δρ(basis, ψ, res.δψ, occupation, res.δoccupation; occupation_threshold, q)

src/response/hessian.jl

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ end
103103

104104
"""
105105
Solve density-functional perturbation theory problem,
106-
that is return δψ where (Ω+K) δψ = rhs.
106+
that is return δψ where (Ω+K) δψ = -δHextψ.
107107
"""
108-
@timing function solve_ΩplusK(basis::PlaneWaveBasis{T}, ψ, rhs, occupation;
108+
@timing function solve_ΩplusK(basis::PlaneWaveBasis{T}, ψ, δHextψ, occupation;
109109
callback=ResponseCallback(), tol=1e-10) where {T}
110110
# for now, all orbitals have to be fully occupied -> need to strip them beforehand
111111
check_full_occupation(basis, occupation)
@@ -118,9 +118,9 @@ that is return δψ where (Ω+K) δψ = rhs.
118118
unpack(x) = unpack_ψ(reinterpret_complex(x), size.(ψ))
119119
unsafe_unpack(x) = unsafe_unpack_ψ(reinterpret_complex(x), size.(ψ))
120120

121-
# project rhs on the tangent space before starting
122-
proj_tangent!(rhs, ψ)
123-
rhs_pack = pack(rhs)
121+
# project δHextψ on the tangent space before starting
122+
proj_tangent!(δHextψ, ψ)
123+
δHextψ_pack = pack(δHextψ)
124124

125125
# preconditioner
126126
Pks = [PreconditionerTPA(basis, kpt) for kpt in basis.kpoints]
@@ -145,15 +145,15 @@ that is return δψ where (Ω+K) δψ = rhs.
145145
Ωδψ = apply_Ω(δψ, ψ, H, Λ)
146146
pack(Ωδψ + Kδψ)
147147
end
148-
J = LinearMap{T}(ΩpK, size(rhs_pack, 1))
148+
J = LinearMap{T}(ΩpK, size(δHextψ_pack, 1))
149149

150-
# solve (Ω+K) δψ = rhs on the tangent space with CG
150+
# solve (Ω+K) δψ = -δHextψ on the tangent space with CG
151151
function proj(x)
152152
δψ = unpack(x)
153153
proj_tangent!(δψ, ψ)
154154
pack(δψ)
155155
end
156-
res = cg(J, rhs_pack; precon=FunctionPreconditioner(f_ldiv!), proj, tol,
156+
res = cg(J, -δHextψ_pack; precon=FunctionPreconditioner(f_ldiv!), proj, tol,
157157
callback, comm=basis.comm_kpts)
158158
(; δψ=unpack(res.x), res.converged, res.tol, res.residual_norm,
159159
res.n_iter)
@@ -219,10 +219,10 @@ function (cb::OmegaPlusKDefaultCallback)(info)
219219
end
220220

221221
"""
222-
Solve the problem `(Ω+K) δψ = rhs` (density-functional perturbation theory)
223-
using a split algorithm, where `rhs` is typically
224-
`-δHextψ` (the negative matvec of an external perturbation with the SCF orbitals `ψ`) and
225-
`δψ` is the corresponding total variation in the orbitals `ψ`. Additionally returns:
222+
Solve the problem `(Ω+K) δψ = -δHextψ` (density-functional perturbation theory)
223+
using a split algorithm, where
224+
`δψ` is the total variation in the orbitals `ψ` corresponding to the external perturbation δHext.
225+
Additionally returns:
226226
- `δρ`: Total variation in density
227227
- `δHψ`: Total variation in Hamiltonian applied to orbitals
228228
- `δeigenvalues`: Total variation in eigenvalues
@@ -243,7 +243,8 @@ Input parameters:
243243
see [arxiv 2505.02319](https://arxiv.org/pdf/2505.02319) for more details.
244244
"""
245245
@timing function solve_ΩplusK_split(ham::Hamiltonian, ρ::AbstractArray{T}, ψ, occupation, εF,
246-
eigenvalues, rhs;
246+
eigenvalues, δHextψ;
247+
δtemperature=zero(real(T)),
247248
tol=1e-8, verbose=true,
248249
mixing=SimpleMixing(),
249250
occupation_threshold,
@@ -268,7 +269,7 @@ Input parameters:
268269
# = χ04P (-1 + E K2P (1 - χ02P K2P)^-1 R (-χ04P))
269270
# where χ02P = R χ04P E and K2P = R K E
270271
basis = ham.basis
271-
@assert size(rhs[1]) == size(ψ[1]) # Assume the same number of bands in ψ and rhs
272+
@assert size(δHextψ[1]) == size(ψ[1])
272273
start_ns = time_ns()
273274

274275
# TODO Better initial guess handling. Especially between the last iteration of the GMRES
@@ -281,10 +282,11 @@ Input parameters:
281282

282283
# compute δρ0 (ignoring interactions)
283284
δρ0 = let # Make sure memory owned by res0 is freed
284-
res0 = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, -rhs;
285+
res0 = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHextψ;
286+
δtemperature,
285287
maxiter=maxiter_sternheimer, tol=tol * factor_initial,
286288
bandtolalg, occupation_threshold,
287-
q, kwargs...) # = -χ04P * rhs
289+
q, kwargs...) # = χ04P * δHext
288290
callback((; stage=:noninteracting, runtime_ns=time_ns() - start_ns,
289291
Axinfos=[(; basis, tol=tol*factor_initial, res0...)]))
290292
compute_δρ(basis, ψ, res0.δψ, occupation, res0.δoccupation;
@@ -308,42 +310,46 @@ Input parameters:
308310
@warn "Solve_ΩplusK_split solver not converged"
309311
end
310312

311-
# Compute total change in Hamiltonian applied to ψ
313+
# Now we got δρ, but we're not done yet, because we want the full output of the four-point apply_χ0_4P,
314+
# so we redo an apply_χ0_4P
315+
316+
# Induced potential variation
312317
δVind = apply_kernel(basis, δρ; ρ, q) # Change in potential induced by δρ
313318

319+
# Total variation δHtot ψ
314320
# For phonon calculations, assemble
315321
# δHψ_k = δV_{q} · ψ_{k-q}.
316-
δHψ = multiply_ψ_by_blochwave(basis, ψ, δVind, q) .- rhs
317-
318-
# Compute total change in eigenvalues
319-
δeigenvalues = map(ψ, δHψ) do ψk, δHψk
320-
map(eachcol(ψk), eachcol(δHψk)) do ψnk, δHψnk
321-
real(dot(ψnk, δHψnk)) # δε_{nk} = <ψnk | δH | ψnk>
322-
end
323-
end
322+
δHtotψ = multiply_ψ_by_blochwave(basis, ψ, δVind, q) .+ δHextψ
324323

325324
# Compute final orbital response
326325
# TODO Here we just use what DFTK did before the inexact Krylov business, namely
327326
# a fixed Sternheimer tolerance of tol / 10. There are probably
328327
# smarter things one could do here
329-
resfinal = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHψ;
328+
resfinal = apply_χ0_4P(ham, ψ, occupation, εF, eigenvalues, δHtotψ;
329+
δtemperature,
330330
maxiter=maxiter_sternheimer, tol=tol * factor_final,
331331
bandtolalg, occupation_threshold, q, kwargs...)
332332
callback((; stage=:final, runtime_ns=time_ns() - start_ns,
333333
Axinfos=[(; basis, tol=tol*factor_final, resfinal...)]))
334+
# Compute total change in eigenvalues
335+
δeigenvalues = map(ψ, δHtotψ) do ψk, δHtotψk
336+
map(eachcol(ψk), eachcol(δHtotψk)) do ψnk, δHtotψnk
337+
real(dot(ψnk, δHtotψnk)) # δε_{nk} = <ψnk | δHtot | ψnk>
338+
end
339+
end
334340

335-
(; resfinal.δψ, δρ, δHψ, δVind, δρ0, δeigenvalues, resfinal.δoccupation,
341+
(; resfinal.δψ, δρ, δHtotψ, δVind, δρ0, δeigenvalues, resfinal.δoccupation,
336342
resfinal.δεF, ε_adj, info_gmres)
337343
end
338344

339-
function solve_ΩplusK_split(scfres::NamedTuple, rhs; kwargs...)
345+
function solve_ΩplusK_split(scfres::NamedTuple, δHextψ; kwargs...)
340346
if (scfres.mixing isa KerkerMixing || scfres.mixing isa KerkerDosMixing)
341347
mixing = scfres.mixing
342348
else
343349
mixing = SimpleMixing()
344350
end
345351
solve_ΩplusK_split(scfres.ham, scfres.ρ, scfres.ψ, scfres.occupation,
346-
scfres.εF, scfres.eigenvalues, rhs;
352+
scfres.εF, scfres.eigenvalues, δHextψ;
347353
scfres.occupation_threshold, mixing,
348354
bandtolalg=BandtolBalanced(scfres), kwargs...)
349355
end

src/scf/newton.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function newton(basis::PlaneWaveBasis{T}, ψ0;
117117
# compute Newton step and next iteration
118118
res = compute_projected_gradient(basis, ψ, occupation)
119119
# solve (Ω+K) δψ = -res so that the Newton step is ψ <- ψ + δψ
120-
δψ = solve_ΩplusK(basis, ψ, -res, occupation; tol=tol_cg, callback=identity).δψ
120+
δψ = solve_ΩplusK(basis, ψ, res, occupation; tol=tol_cg, callback=identity).δψ
121121
ψ = [ortho_qr(ψ[ik] + δψ[ik]) for ik = 1:Nk]
122122

123123
ρout = compute_density(basis, ψ, occupation)

src/workarounds/forwarddiff_rules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ function construct_value(model::Model{T}) where {T <: Dual}
166166
newpositions;
167167
model.model_name,
168168
model.n_electrons,
169-
magnetic_moments=[], # Symmetries given explicitly
169+
magnetic_moments=value_type(T)[], # Symmetries given explicitly
170170
terms=model.term_types,
171171
temperature=ForwardDiff.value(model.temperature),
172172
model.smearing,
@@ -233,12 +233,12 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
233233
scfres.εF).ham
234234
ham_dual * scfres.ψ
235235
end
236-
237236
# Implicit differentiation
238237
response.verbose && println("Solving response problem")
239238
δresults = ntuple(ForwardDiff.npartials(T)) do α
240239
δHextψ = [ForwardDiff.partials.(δHextψk, α) for δHextψk in Hψ_dual]
241-
solve_ΩplusK_split(scfres, -δHextψ;
240+
δtemperature = ForwardDiff.partials(basis_dual.model.temperature, α)
241+
solve_ΩplusK_split(scfres, δHextψ; δtemperature,
242242
tol=last(scfres.history_Δρ), response.verbose)
243243
end
244244

test/forwarddiff.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,4 +402,33 @@ end
402402

403403
# Check that scfres_dual has the same parameters as scfres
404404
@test isempty(setdiff(keys(scfres), keys(scfres_dual)))
405-
end
405+
end
406+
407+
408+
@testitem "ForwardDiff wrt temperature" tags=[:dont_test_mpi, :minimal] begin
409+
using DFTK
410+
using ForwardDiff
411+
using LinearAlgebra
412+
using PseudoPotentialData
413+
414+
a = 10.26 # Silicon lattice constant in Bohr
415+
lattice = a / 2 * [[0 1 1.];
416+
[1 0 1.];
417+
[1 1 0.]]
418+
Si = ElementPsp(:Si, PseudoFamily("dojo.nc.sr.lda.v0_4_1.standard.upf"))
419+
atoms = [Si, Si]
420+
positions = [ones(3)/8, -ones(3)/8]
421+
422+
function get(T)
423+
model = model_DFT(lattice, atoms, positions; functionals=LDA(), temperature=T)
424+
basis = PlaneWaveBasis(model; Ecut=10, kgrid=[1, 1, 1])
425+
scfres = self_consistent_field(basis, tol=1e-12)
426+
scfres.energies.total
427+
end
428+
T0 = .01
429+
derivative_ε = let ε = 1e-5
430+
(get(T0+ε) - get(T0-ε)) / 2ε
431+
end
432+
derivative_fd = ForwardDiff.derivative(get, T0)
433+
@test norm(derivative_ε - derivative_fd) < 1e-4
434+
end

0 commit comments

Comments
 (0)