Skip to content

Commit e2672b6

Browse files
authored
ForwardDiff wrappers to reduce compilation time (#1182)
1 parent c785222 commit e2672b6

File tree

1 file changed

+71
-24
lines changed

1 file changed

+71
-24
lines changed

test/forwarddiff.jl

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,47 @@
1+
# Wrappers around ForwardDiff to fix tags and reduce compilation time.
2+
# Warning: fixing types without care can lead to perturbation confusion,
3+
# this should only be used within the testing framework. Risk of perturbation
4+
# confusion arises when nested derivatives of different functions are taken
5+
# with a fixed tag. Only use these wrappers at the top-level call.
6+
@testmodule ForwardDiffWrappers begin
7+
using ForwardDiff
8+
export tagged_derivative, tagged_gradient, tagged_jacobian
9+
10+
struct DerivativeTag end
11+
function tagged_derivative(f, x::T; custom_tag=DerivativeTag) where T
12+
# explicit call to ForwardDiff.Tag() to trigger ForwardDiff.tagcount
13+
TagType = typeof(ForwardDiff.Tag(custom_tag(), T))
14+
x_dual = ForwardDiff.Dual{TagType, T, 1}(x, ForwardDiff.Partials((one(T),)))
15+
16+
res = ForwardDiff.extract_derivative(TagType, f(x_dual))
17+
return res
18+
end
19+
20+
struct GradientTag end
21+
GradientConfig(f, x, ::Type{Tag}) where {Tag} =
22+
ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(x), Tag())
23+
function tagged_gradient(f, x::AbstractArray{T}; custom_tag=GradientTag) where T
24+
# explicit call to ForwardDiff.Tag() to trigger ForwardDiff.tagcount
25+
TagType = typeof(ForwardDiff.Tag(custom_tag(), T))
26+
27+
cfg = GradientConfig(f, x, TagType)
28+
ForwardDiff.gradient(f, x, cfg, Val{false}())
29+
end
30+
31+
struct JacobianTag end
32+
JacobianConfig(f, x, ::Type{Tag}) where {Tag} =
33+
ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(x), Tag())
34+
function tagged_jacobian(f, x::AbstractArray{T}; custom_tag=JacobianTag) where T
35+
# explicit call to ForwardDiff.Tag() to trigger ForwardDiff.tagcount
36+
TagType = typeof(ForwardDiff.Tag(custom_tag(), T))
37+
38+
cfg = JacobianConfig(f, x, TagType)
39+
ForwardDiff.jacobian(f, x, cfg, Val{false}())
40+
end
41+
end
42+
143
@testitem "Force derivatives using ForwardDiff" #=
2-
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases] begin
44+
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases, ForwardDiffWrappers] begin
345
using DFTK
446
using ForwardDiff
547
using LinearAlgebra
@@ -28,21 +70,21 @@
2870
derivative_ε1_fd = let ε1 = 1e-5
2971
(compute_force(ε1, 0.0) - F) / ε1
3072
end
31-
derivative_ε1 = ForwardDiff.derivative(ε1 -> compute_force(ε1, 0.0), 0.0)
73+
derivative_ε1 = tagged_derivative(ε1 -> compute_force(ε1, 0.0), 0.0)
3274
@test norm(derivative_ε1 - derivative_ε1_fd) < 1e-4
3375

3476
derivative_ε2_fd = let ε2 = 1e-5
3577
(compute_force(0.0, ε2) - F) / ε2
3678
end
37-
derivative_ε2 = ForwardDiff.derivative(ε2 -> compute_force(0.0, ε2), 0.0)
79+
derivative_ε2 = tagged_derivative(ε2 -> compute_force(0.0, ε2), 0.0)
3880
@test norm(derivative_ε2 - derivative_ε2_fd) < 1e-4
3981

4082
@testset "Multiple partials" begin
41-
grad = ForwardDiff.gradient(v -> compute_force(v...)[1][1], [0.0, 0.0])
83+
grad = tagged_gradient(v -> compute_force(v...)[1][1], [0.0, 0.0])
4284
@test abs(grad[1] - derivative_ε1[1][1]) < 1e-4
4385
@test abs(grad[2] - derivative_ε2[1][1]) < 1e-4
4486

45-
jac = ForwardDiff.jacobian(v -> compute_force(v...)[1], [0.0, 0.0])
87+
jac = tagged_jacobian(v -> compute_force(v...)[1], [0.0, 0.0])
4688
@test norm(grad - jac[1, :]) < 1e-9
4789
end
4890

@@ -51,7 +93,7 @@
5193
derivative_ε1_fd = let ε1 = 1e-5
5294
(compute_force(ε1, 0.0; metal) - compute_force(-ε1, 0.0; metal)) / 2ε1
5395
end
54-
derivative_ε1 = ForwardDiff.derivative(ε1 -> compute_force(ε1, 0.0; metal), 0.0)
96+
derivative_ε1 = tagged_derivative(ε1 -> compute_force(ε1, 0.0; metal), 0.0)
5597
@test norm(derivative_ε1 - derivative_ε1_fd) < 1e-4
5698
end
5799

@@ -62,13 +104,13 @@
62104
derivative_ε1_fd = let ε1 = 1e-5
63105
(compute_force(ε1, 0.0; atoms) - compute_force(-ε1, 0.0; atoms)) / 2ε1
64106
end
65-
derivative_ε1 = ForwardDiff.derivative(ε1 -> compute_force(ε1, 0.0; atoms), 0.0)
107+
derivative_ε1 = tagged_derivative(ε1 -> compute_force(ε1, 0.0; atoms), 0.0)
66108
@test norm(derivative_ε1 - derivative_ε1_fd) < 1e-4
67109
end
68110
end
69111

70112
@testitem "Anisotropic strain sensitivity using ForwardDiff" #=
71-
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases] begin
113+
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases, ForwardDiffWrappers] begin
72114
using DFTK
73115
using ForwardDiff
74116
using LinearAlgebra
@@ -103,7 +145,7 @@ end
103145

104146
@testset "$strain_fn" for strain_fn in (strain_isotropic, strain_anisotropic)
105147
f(ε) = compute_properties(strain_fn(ε))
106-
dx = ForwardDiff.derivative(f, 0.)
148+
dx = tagged_derivative(f, 0.)
107149

108150
h = 1e-4
109151
x1 = f(-h)
@@ -118,7 +160,7 @@ end
118160
end
119161

120162
@testitem "scfres PSP sensitivity using ForwardDiff" #=
121-
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases] begin
163+
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases, ForwardDiffWrappers] begin
122164
using DFTK
123165
using ForwardDiff
124166
using LinearAlgebra
@@ -155,12 +197,12 @@ end
155197
derivative_ε = let ε = 1e-4
156198
(compute_band_energies(ε) - compute_band_energies(-ε)) / 2ε
157199
end
158-
derivative_fd = ForwardDiff.derivative(compute_band_energies, 0.0)
200+
derivative_fd = tagged_derivative(compute_band_energies, 0.0)
159201
@test norm(derivative_fd - derivative_ε) < 5e-4
160202
end
161203

162204
@testitem "Functional force sensitivity using ForwardDiff" #=
163-
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases] begin
205+
=# tags=[:dont_test_mpi, :minimal] setup=[TestCases, ForwardDiffWrappers] begin
164206
using DFTK
165207
using ForwardDiff
166208
using LinearAlgebra
@@ -187,11 +229,12 @@ end
187229
derivative_ε = let ε = 1e-5
188230
(compute_force(ε) - compute_force(-ε)) / 2ε
189231
end
190-
derivative_fd = ForwardDiff.derivative(compute_force, 0.0)
232+
derivative_fd = tagged_derivative(compute_force, 0.0)
191233
@test norm(derivative_ε - derivative_fd) < 1e-4
192234
end
193235

194-
@testitem "Derivative of complex function" tags=[:dont_test_mpi, :minimal] begin
236+
@testitem "Derivative of complex function" #=
237+
=# tags=[:dont_test_mpi, :minimal] setup=[ForwardDiffWrappers] begin
195238
using DFTK
196239
using ForwardDiff
197240
using LinearAlgebra
@@ -202,12 +245,13 @@ end
202245
erfcα = x -> erfc* x)
203246

204247
x0 = randn()
205-
fd1 = ForwardDiff.derivative(erfcα, x0)
248+
fd1 = tagged_derivative(erfcα, x0)
206249
fd2 = FiniteDifferences.central_fdm(5, 1)(erfcα, x0)
207250
@test norm(fd1 - fd2) < 1e-8
208251
end
209252

210-
@testitem "Higher derivatives of Fermi-Dirac occupation" tags=[:dont_test_mpi, :minimal] begin
253+
@testitem "Higher derivatives of Fermi-Dirac occupation" #=
254+
=# tags=[:dont_test_mpi, :minimal] setup=[ForwardDiffWrappers] begin
211255
using DFTK
212256
using ForwardDiff
213257

@@ -232,7 +276,8 @@ end
232276
end
233277
end
234278

235-
@testitem "LocalNonlinearity sensitivity using ForwardDiff" tags=[:dont_test_mpi, :minimal] begin
279+
@testitem "LocalNonlinearity sensitivity using ForwardDiff" #=
280+
=# tags=[:dont_test_mpi, :minimal] setup=[ForwardDiffWrappers] begin
236281
using DFTK
237282
using ForwardDiff
238283
using LinearAlgebra
@@ -257,11 +302,12 @@ end
257302
derivative_ε = let ε = 1e-5
258303
(compute_force(ε) - compute_force(-ε)) / 2ε
259304
end
260-
derivative_fd = ForwardDiff.derivative(compute_force, 0.0)
305+
derivative_fd = tagged_derivative(compute_force, 0.0)
261306
@test norm(derivative_ε - derivative_fd) < 1e-4
262307
end
263308

264-
@testitem "Symmetries broken by perturbation are filtered out" tags=[:dont_test_mpi] begin
309+
@testitem "Symmetries broken by perturbation are filtered out" #=
310+
=# tags=[:dont_test_mpi] setup=[ForwardDiffWrappers] begin
265311
using DFTK
266312
using ForwardDiff
267313
using LinearAlgebra
@@ -343,7 +389,7 @@ end
343389
end
344390

345391
@testitem "Symmetry-breaking perturbation using ForwardDiff" #=
346-
=# tags=[:dont_test_mpi] setup=[TestCases] begin
392+
=# tags=[:dont_test_mpi] setup=[TestCases, ForwardDiffWrappers] begin
347393
using DFTK
348394
using ForwardDiff
349395
using LinearAlgebra
@@ -373,7 +419,7 @@ end
373419
self_consistent_field(basis; tol=1e-10)
374420
end
375421

376-
δρ = ForwardDiff.derivative-> run_scf(ε).ρ, 0.)
422+
δρ = tagged_derivative-> run_scf(ε).ρ, 0.)
377423

378424
h = 1e-5
379425
scfres1 = run_scf(-h)
@@ -386,7 +432,7 @@ end
386432
end
387433

388434
@testitem "Test scfres dual has the same params as scfres primal" #=
389-
=# tags=[:dont_test_mpi] setup=[TestCases] begin
435+
=# tags=[:dont_test_mpi] setup=[TestCases, ForwardDiffWrappers] begin
390436
using DFTK
391437
using ForwardDiff
392438
using LinearAlgebra
@@ -416,7 +462,8 @@ end
416462
end
417463

418464

419-
@testitem "ForwardDiff wrt temperature" tags=[:dont_test_mpi, :minimal] begin
465+
@testitem "ForwardDiff wrt temperature" #=
466+
=# tags=[:dont_test_mpi, :minimal] setup=[ForwardDiffWrappers] begin
420467
using DFTK
421468
using ForwardDiff
422469
using LinearAlgebra
@@ -440,6 +487,6 @@ end
440487
derivative_ε = let ε = 1e-5
441488
(get(T0+ε) - get(T0-ε)) / 2ε
442489
end
443-
derivative_fd = ForwardDiff.derivative(get, T0)
490+
derivative_fd = tagged_derivative(get, T0)
444491
@test norm(derivative_ε - derivative_fd) < 1e-4
445492
end

0 commit comments

Comments
 (0)