Skip to content

Commit ccaa5c7

Browse files
authored
Add integration tests for Checkpointing (#2811)
1 parent 30e0519 commit ccaa5c7

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

.github/workflows/Integration.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
- linux-x86-n2-32
4747
package:
4848
- Bijectors
49+
- Checkpointing
4950
- DifferentiationInterface
5051
- Distributions
5152
- DynamicExpressions
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[deps]
2+
Checkpointing = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
3+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
5+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6+
7+
[compat]
8+
Checkpointing = "0.11"
9+
10+
[sources.Enzyme]
11+
path = "../../.."
12+
13+
[sources.EnzymeCore]
14+
path = "../../../lib/EnzymeCore"
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using Test
2+
using Enzyme
3+
using Checkpointing
4+
using LinearAlgebra
5+
6+
# Explicit 1D heat equation
7+
mutable struct Heat
8+
Tnext::Vector{Float64}
9+
Tlast::Vector{Float64}
10+
n::Int
11+
λ::Float64
12+
tsteps::Int
13+
end
14+
15+
function advance(heat)
16+
next = heat.Tnext
17+
last = heat.Tlast
18+
λ = heat.λ
19+
n = heat.n
20+
for i in 2:(n - 1)
21+
next[i] = last[i] + λ * (last[i - 1] - 2 * last[i] + last[i + 1])
22+
end
23+
return nothing
24+
end
25+
26+
27+
function sumheat(heat::Heat, chkpscheme::Scheme, tsteps::Int64)
28+
@ad_checkpoint chkpscheme for i in 1:tsteps
29+
heat.Tlast .= heat.Tnext
30+
advance(heat)
31+
end
32+
return reduce(+, heat.Tnext)
33+
end
34+
35+
function heat(scheme::Scheme, tsteps::Int)
36+
n = 100
37+
Δx = 0.1
38+
Δt = 0.001
39+
# Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2
40+
λ = 0.5
41+
42+
# Create object from struct. tsteps is not needed for a for-loop
43+
heat = Heat(zeros(n), zeros(n), n, λ, tsteps)
44+
# Shadow copy for Enzyme
45+
dheat = Heat(zeros(n), zeros(n), n, λ, tsteps)
46+
47+
# Boundary conditions
48+
heat.Tnext[1] = 20.0
49+
heat.Tnext[end] = 0
50+
51+
# Compute gradient
52+
autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), Const(scheme), Const(tsteps))
53+
54+
return heat.Tnext, dheat.Tnext[2:(end - 1)]
55+
end
56+
57+
T, dT = heat(Revolve(4), 500)
58+
@test norm(T) == 66.21987468492061
59+
@test norm(dT) == 6.970279349365908

0 commit comments

Comments
 (0)