1+ # Wrappers around ForwardDiff to fix tags and reduce compilation time.
2+ # Warning: fixing types without care can lead to perturbation confusion,
3+ # this should only be used within the testing framework. Risk of perturbation
4+ # confusion arises when nested derivatives of different functions are taken
5+ # with a fixed tag. Only use these wrappers at the top-level call.
6+ @testmodule ForwardDiffWrappers begin
7+ using ForwardDiff
8+ export tagged_derivative, tagged_gradient, tagged_jacobian
9+
10+ struct DerivativeTag end
11+ function tagged_derivative (f, x:: T ; custom_tag= DerivativeTag) where T
12+ # explicit call to ForwardDiff.Tag() to trigger ForwardDiff.tagcount
13+ TagType = typeof (ForwardDiff. Tag (custom_tag (), T))
14+ x_dual = ForwardDiff. Dual {TagType, T, 1} (x, ForwardDiff. Partials ((one (T),)))
15+
16+ res = ForwardDiff. extract_derivative (TagType, f (x_dual))
17+ return res
18+ end
19+
20+ struct GradientTag end
21+ GradientConfig (f, x, :: Type{Tag} ) where {Tag} =
22+ ForwardDiff. GradientConfig (f, x, ForwardDiff. Chunk (x), Tag ())
23+ function tagged_gradient (f, x:: AbstractArray{T} ; custom_tag= GradientTag) where T
24+ # explicit call to ForwardDiff.Tag() to trigger ForwardDiff.tagcount
25+ TagType = typeof (ForwardDiff. Tag (custom_tag (), T))
26+
27+ cfg = GradientConfig (f, x, TagType)
28+ ForwardDiff. gradient (f, x, cfg, Val {false} ())
29+ end
30+
31+ struct JacobianTag end
32+ JacobianConfig (f, x, :: Type{Tag} ) where {Tag} =
33+ ForwardDiff. JacobianConfig (f, x, ForwardDiff. Chunk (x), Tag ())
34+ function tagged_jacobian (f, x:: AbstractArray{T} ; custom_tag= JacobianTag) where T
35+ # explicit call to ForwardDiff.Tag() to trigger ForwardDiff.tagcount
36+ TagType = typeof (ForwardDiff. Tag (custom_tag (), T))
37+
38+ cfg = JacobianConfig (f, x, TagType)
39+ ForwardDiff. jacobian (f, x, cfg, Val {false} ())
40+ end
41+ end
42+
143@testitem " Force derivatives using ForwardDiff" #=
2- =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases] begin
44+ =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases, ForwardDiffWrappers ] begin
345 using DFTK
446 using ForwardDiff
547 using LinearAlgebra
2870 derivative_ε1_fd = let ε1 = 1e-5
2971 (compute_force (ε1, 0.0 ) - F) / ε1
3072 end
31- derivative_ε1 = ForwardDiff . derivative (ε1 -> compute_force (ε1, 0.0 ), 0.0 )
73+ derivative_ε1 = tagged_derivative (ε1 -> compute_force (ε1, 0.0 ), 0.0 )
3274 @test norm (derivative_ε1 - derivative_ε1_fd) < 1e-4
3375
3476 derivative_ε2_fd = let ε2 = 1e-5
3577 (compute_force (0.0 , ε2) - F) / ε2
3678 end
37- derivative_ε2 = ForwardDiff . derivative (ε2 -> compute_force (0.0 , ε2), 0.0 )
79+ derivative_ε2 = tagged_derivative (ε2 -> compute_force (0.0 , ε2), 0.0 )
3880 @test norm (derivative_ε2 - derivative_ε2_fd) < 1e-4
3981
4082 @testset " Multiple partials" begin
41- grad = ForwardDiff . gradient (v -> compute_force (v... )[1 ][1 ], [0.0 , 0.0 ])
83+ grad = tagged_gradient (v -> compute_force (v... )[1 ][1 ], [0.0 , 0.0 ])
4284 @test abs (grad[1 ] - derivative_ε1[1 ][1 ]) < 1e-4
4385 @test abs (grad[2 ] - derivative_ε2[1 ][1 ]) < 1e-4
4486
45- jac = ForwardDiff . jacobian (v -> compute_force (v... )[1 ], [0.0 , 0.0 ])
87+ jac = tagged_jacobian (v -> compute_force (v... )[1 ], [0.0 , 0.0 ])
4688 @test norm (grad - jac[1 , :]) < 1e-9
4789 end
4890
5193 derivative_ε1_fd = let ε1 = 1e-5
5294 (compute_force (ε1, 0.0 ; metal) - compute_force (- ε1, 0.0 ; metal)) / 2 ε1
5395 end
54- derivative_ε1 = ForwardDiff . derivative (ε1 -> compute_force (ε1, 0.0 ; metal), 0.0 )
96+ derivative_ε1 = tagged_derivative (ε1 -> compute_force (ε1, 0.0 ; metal), 0.0 )
5597 @test norm (derivative_ε1 - derivative_ε1_fd) < 1e-4
5698 end
5799
62104 derivative_ε1_fd = let ε1 = 1e-5
63105 (compute_force (ε1, 0.0 ; atoms) - compute_force (- ε1, 0.0 ; atoms)) / 2 ε1
64106 end
65- derivative_ε1 = ForwardDiff . derivative (ε1 -> compute_force (ε1, 0.0 ; atoms), 0.0 )
107+ derivative_ε1 = tagged_derivative (ε1 -> compute_force (ε1, 0.0 ; atoms), 0.0 )
66108 @test norm (derivative_ε1 - derivative_ε1_fd) < 1e-4
67109 end
68110end
69111
70112@testitem " Anisotropic strain sensitivity using ForwardDiff" #=
71- =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases] begin
113+ =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases, ForwardDiffWrappers ] begin
72114 using DFTK
73115 using ForwardDiff
74116 using LinearAlgebra
103145
104146 @testset " $strain_fn " for strain_fn in (strain_isotropic, strain_anisotropic)
105147 f (ε) = compute_properties (strain_fn (ε))
106- dx = ForwardDiff . derivative (f, 0. )
148+ dx = tagged_derivative (f, 0. )
107149
108150 h = 1e-4
109151 x1 = f (- h)
118160end
119161
120162@testitem " scfres PSP sensitivity using ForwardDiff" #=
121- =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases] begin
163+ =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases, ForwardDiffWrappers ] begin
122164 using DFTK
123165 using ForwardDiff
124166 using LinearAlgebra
@@ -155,12 +197,12 @@ end
155197 derivative_ε = let ε = 1e-4
156198 (compute_band_energies (ε) - compute_band_energies (- ε)) / 2 ε
157199 end
158- derivative_fd = ForwardDiff . derivative (compute_band_energies, 0.0 )
200+ derivative_fd = tagged_derivative (compute_band_energies, 0.0 )
159201 @test norm (derivative_fd - derivative_ε) < 5e-4
160202end
161203
162204@testitem " Functional force sensitivity using ForwardDiff" #=
163- =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases] begin
205+ =# tags= [:dont_test_mpi , :minimal ] setup= [TestCases, ForwardDiffWrappers ] begin
164206 using DFTK
165207 using ForwardDiff
166208 using LinearAlgebra
@@ -187,11 +229,12 @@ end
187229 derivative_ε = let ε = 1e-5
188230 (compute_force (ε) - compute_force (- ε)) / 2 ε
189231 end
190- derivative_fd = ForwardDiff . derivative (compute_force, 0.0 )
232+ derivative_fd = tagged_derivative (compute_force, 0.0 )
191233 @test norm (derivative_ε - derivative_fd) < 1e-4
192234end
193235
194- @testitem " Derivative of complex function" tags= [:dont_test_mpi , :minimal ] begin
236+ @testitem " Derivative of complex function" #=
237+ =# tags= [:dont_test_mpi , :minimal ] setup= [ForwardDiffWrappers] begin
195238 using DFTK
196239 using ForwardDiff
197240 using LinearAlgebra
@@ -202,12 +245,13 @@ end
202245 erfcα = x -> erfc (α * x)
203246
204247 x0 = randn ()
205- fd1 = ForwardDiff . derivative (erfcα, x0)
248+ fd1 = tagged_derivative (erfcα, x0)
206249 fd2 = FiniteDifferences. central_fdm (5 , 1 )(erfcα, x0)
207250 @test norm (fd1 - fd2) < 1e-8
208251end
209252
210- @testitem " Higher derivatives of Fermi-Dirac occupation" tags= [:dont_test_mpi , :minimal ] begin
253+ @testitem " Higher derivatives of Fermi-Dirac occupation" #=
254+ =# tags= [:dont_test_mpi , :minimal ] setup= [ForwardDiffWrappers] begin
211255 using DFTK
212256 using ForwardDiff
213257
232276 end
233277end
234278
235- @testitem " LocalNonlinearity sensitivity using ForwardDiff" tags= [:dont_test_mpi , :minimal ] begin
279+ @testitem " LocalNonlinearity sensitivity using ForwardDiff" #=
280+ =# tags= [:dont_test_mpi , :minimal ] setup= [ForwardDiffWrappers] begin
236281 using DFTK
237282 using ForwardDiff
238283 using LinearAlgebra
@@ -257,11 +302,12 @@ end
257302 derivative_ε = let ε = 1e-5
258303 (compute_force (ε) - compute_force (- ε)) / 2 ε
259304 end
260- derivative_fd = ForwardDiff . derivative (compute_force, 0.0 )
305+ derivative_fd = tagged_derivative (compute_force, 0.0 )
261306 @test norm (derivative_ε - derivative_fd) < 1e-4
262307end
263308
264- @testitem " Symmetries broken by perturbation are filtered out" tags= [:dont_test_mpi ] begin
309+ @testitem " Symmetries broken by perturbation are filtered out" #=
310+ =# tags= [:dont_test_mpi ] setup= [ForwardDiffWrappers] begin
265311 using DFTK
266312 using ForwardDiff
267313 using LinearAlgebra
343389end
344390
345391@testitem " Symmetry-breaking perturbation using ForwardDiff" #=
346- =# tags= [:dont_test_mpi ] setup= [TestCases] begin
392+ =# tags= [:dont_test_mpi ] setup= [TestCases, ForwardDiffWrappers ] begin
347393 using DFTK
348394 using ForwardDiff
349395 using LinearAlgebra
373419 self_consistent_field (basis; tol= 1e-10 )
374420 end
375421
376- δρ = ForwardDiff . derivative (ε -> run_scf (ε). ρ, 0. )
422+ δρ = tagged_derivative (ε -> run_scf (ε). ρ, 0. )
377423
378424 h = 1e-5
379425 scfres1 = run_scf (- h)
386432end
387433
388434@testitem " Test scfres dual has the same params as scfres primal" #=
389- =# tags= [:dont_test_mpi ] setup= [TestCases] begin
435+ =# tags= [:dont_test_mpi ] setup= [TestCases, ForwardDiffWrappers ] begin
390436 using DFTK
391437 using ForwardDiff
392438 using LinearAlgebra
416462end
417463
418464
419- @testitem " ForwardDiff wrt temperature" tags= [:dont_test_mpi , :minimal ] begin
465+ @testitem " ForwardDiff wrt temperature" #=
466+ =# tags= [:dont_test_mpi , :minimal ] setup= [ForwardDiffWrappers] begin
420467 using DFTK
421468 using ForwardDiff
422469 using LinearAlgebra
440487 derivative_ε = let ε = 1e-5
441488 (get (T0+ ε) - get (T0- ε)) / 2 ε
442489 end
443- derivative_fd = ForwardDiff . derivative (get, T0)
490+ derivative_fd = tagged_derivative (get, T0)
444491 @test norm (derivative_ε - derivative_fd) < 1e-4
445492end
0 commit comments