Skip to content

Commit 4da1237

Browse files
shivang57721niklasschmitzmfherbstTechnici4n
authored
Include missing SCF metadata fields in forwarddiff_rules.jl (#1133)
* Include missing SCF metadata fields in forwarddiff_rules.jl * Removed extra line * Added test to verify that scfres dual has the same parameters as scfres primal * Removed extra tag in test * Simplify scfres dual params test * Use ForwardDiff.Tag(:mytag, Float64) function (which will increment ForwardDiff's internal global Tag counter) instead of directly using ForwardDiff.Tag{...} type constructor * Fix symmetries & symmetry-breaking with ForwardDiff (#1082) --------- Co-authored-by: Bruno Ploumhans <[email protected]> * Added test to verify that scfres dual has the same parameters as scfres primal * Removed extra tag in test * Remove duplicate tau * Re-align named tuple fields --------- Co-authored-by: Niklas Schmitz <[email protected]> Co-authored-by: Michael F. Herbst <[email protected]> Co-authored-by: Bruno Ploumhans <[email protected]>
1 parent 331e873 commit 4da1237

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/workarounds/forwarddiff_rules.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,12 @@ function self_consistent_field(basis_dual::PlaneWaveBasis{T};
269269

270270
# This has to be changed whenever the scfres structure changes
271271
(; ham, basis=basis_dual, energies, ρ, eigenvalues, occupation, εF, ψ,
272+
scfres.τ, # TODO make τ also differentiable for meta-GGA DFPT
272273
# non-differentiable metadata:
273274
response=getfield.(δresults, :info_gmres),
274275
scfres.converged, scfres.occupation_threshold, scfres.α, scfres.n_iter,
275-
scfres.n_bands_converge, scfres.diagonalization, scfres.stage,
276+
scfres.n_bands_converge, scfres.n_matvec, scfres.diagonalization, scfres.stage,
277+
scfres.history_Δρ, scfres.history_Etot, scfres.timedout, scfres.mixing,
276278
scfres.algorithm, scfres.runtime_ns)
277279
end
278280

test/forwarddiff.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,4 +373,33 @@ end
373373
rtol = 1e-4
374374
@test norm(δρ - δρ_finitediff, 1) < rtol * norm(δρ, 1)
375375
end
376+
end
377+
378+
@testitem "Test scfres dual has the same params as scfres primal" #=
379+
=# tags=[:dont_test_mpi] setup=[TestCases] begin
380+
using DFTK
381+
using ForwardDiff
382+
using LinearAlgebra
383+
using PseudoPotentialData
384+
silicon = TestCases.silicon
385+
386+
# Make silicon primal model
387+
model = model_DFT(silicon.lattice, silicon.atoms, silicon.positions;
388+
functionals=LDA(), temperature=1e-3, smearing=Smearing.Gaussian())
389+
390+
# Make silicon dual model
391+
T = typeof(ForwardDiff.Tag(:mytag, Float64))
392+
x_dual = ForwardDiff.Dual{T}(1.0, 1.0)
393+
model_dual = Model(model; lattice=x_dual * model.lattice)
394+
395+
# Construct the primal and dual basis
396+
basis = PlaneWaveBasis(model; Ecut=5, kgrid=(1,1,1))
397+
basis_dual = PlaneWaveBasis(model_dual; Ecut=5, kgrid=(1,1,1))
398+
399+
# Compute scfres with primal and dual basis
400+
scfres = self_consistent_field(basis; tol=1e-5)
401+
scfres_dual = self_consistent_field(basis_dual; tol=1e-5)
402+
403+
# Check that scfres_dual has the same parameters as scfres
404+
@test isempty(setdiff(keys(scfres), keys(scfres_dual)))
376405
end

0 commit comments

Comments
 (0)