Skip to content

Commit a5c9d8d

Browse files
committed
Test successful retcode from different AD
1 parent e0cf88a commit a5c9d8d

File tree

4 files changed

+82
-68
lines changed

4 files changed

+82
-68
lines changed

lib/BoundaryValueDiffEqFIRK/Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ ConcreteStructs = "0.2.3"
3737
DiffEqBase = "6.158.3"
3838
DiffEqDevTools = "2.44"
3939
DifferentiationInterface = "0.6.42"
40+
Enzyme = "0.13.33"
4041
FastAlmostBandedMatrices = "0.1.4"
4142
FastClosures = "0.3.2"
4243
ForwardDiff = "0.10.38, 1"
@@ -45,6 +46,7 @@ InteractiveUtils = "<0.0.1, 1"
4546
JET = "0.9"
4647
LinearAlgebra = "1.10"
4748
LinearSolve = "2.36.2, 3"
49+
Mooncake = "0.4.108"
4850
OrdinaryDiffEq = "6.90.1"
4951
PreallocationTools = "0.4.24"
5052
PrecompileTools = "1.2"
@@ -63,10 +65,12 @@ julia = "1.10"
6365
[extras]
6466
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6567
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
68+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6669
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
6770
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
6871
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
6972
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
73+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
7074
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
7175
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7276
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -75,4 +79,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7579
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7680

7781
[targets]
78-
test = ["Aqua", "DiffEqDevTools", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "OrdinaryDiffEq", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]
82+
test = ["Aqua", "DiffEqDevTools", "Enzyme", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "Mooncake", "OrdinaryDiffEq", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]

lib/BoundaryValueDiffEqFIRK/test/expanded/ad_tests.jl

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,41 +28,43 @@
2828
nonbc_diffmode = AutoEnzyme(
2929
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
3030
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
31-
@test_nowarn sol = solve(
31+
sol = solve(
3232
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = false), dt = 0.05)
33+
@test SciMLBase.successful_retcode(sol)
3334
end
3435
end
35-
#=
36-
@testset "Test different AD on multipoint BVP using Interpolation BC" begin
37-
function simplependulum!(du, u, p, t)
38-
θ = u[1]
39-
dθ = u[2]
40-
du[1] = dθ
41-
du[2] = -9.81 * sin(θ)
42-
end
43-
function bc!(residual, u, p, t)
44-
residual[1] = u(pi / 4)[1] + pi / 2
45-
residual[2] = u(pi / 2)[1] - pi / 2
46-
end
47-
u0 = [pi / 2, pi / 2]
48-
tspan = (0.0, pi / 2)
49-
prob = BVProblem(simplependulum!, bc!, u0, tspan)
50-
jac_alg_forwarddiff = BVPJacobianAlgorithm(
51-
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
52-
jac_alg_enzyme = BVPJacobianAlgorithm(
53-
bc_diffmode = AutoSparse(AutoEnzyme(
54-
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
55-
nonbc_diffmode = AutoEnzyme(
56-
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
57-
jac_alg_mooncake = BVPJacobianAlgorithm(
58-
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
59-
nonbc_diffmode = AutoEnzyme(
60-
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
61-
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
62-
@test_nowarn sol = solve(prob, RadauIIa5(; jac_alg = jac_alg), dt = 0.05)
63-
end
36+
37+
@testset "Test different AD on multipoint BVP using Interpolation BC" begin
38+
function simplependulum!(du, u, p, t)
39+
θ = u[1]
40+
= u[2]
41+
du[1] =
42+
du[2] = -9.81 * sin(θ)
43+
end
44+
function bc!(residual, u, p, t)
45+
residual[1] = u(pi / 4)[1] + pi / 2
46+
residual[2] = u(pi / 2)[1] - pi / 2
47+
end
48+
u0 = [pi / 2, pi / 2]
49+
tspan = (0.0, pi / 2)
50+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
51+
jac_alg_forwarddiff = BVPJacobianAlgorithm(
52+
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
53+
jac_alg_enzyme = BVPJacobianAlgorithm(
54+
bc_diffmode = AutoSparse(AutoEnzyme(
55+
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
56+
nonbc_diffmode = AutoEnzyme(
57+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
58+
jac_alg_mooncake = BVPJacobianAlgorithm(
59+
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
60+
nonbc_diffmode = AutoEnzyme(
61+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
62+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
63+
sol = solve(prob, RadauIIa5(; jac_alg = jac_alg), dt = 0.05)
64+
@test SciMLBase.successful_retcode(sol)
6465
end
65-
=#
66+
end
67+
6668
@testset "Test different AD on twopoint BVP" begin
6769
function f!(du, u, p, t)
6870
du[1] = u[2]
@@ -87,8 +89,9 @@
8789
jac_alg_mooncake = BVPJacobianAlgorithm(AutoSparse(AutoMooncake(;
8890
config = nothing)))
8991
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
90-
@test_nowarn sol = solve(
92+
sol = solve(
9193
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = false), dt = 0.01)
94+
@test SciMLBase.successful_retcode(sol)
9295
end
9396
end
9497
end

lib/BoundaryValueDiffEqFIRK/test/nested/ad_tests.jl

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,41 +28,44 @@
2828
nonbc_diffmode = AutoEnzyme(
2929
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
3030
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
31-
@test_nowarn sol = solve(
31+
sol = solve(
3232
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.05)
33+
@test SciMLBase.successful_retcode(sol)
3334
end
3435
end
35-
#=
36-
@testset "Test different AD on multipoint BVP using Interpolation BC" begin
37-
function simplependulum!(du, u, p, t)
38-
θ = u[1]
39-
dθ = u[2]
40-
du[1] = dθ
41-
du[2] = -9.81 * sin(θ)
42-
end
43-
function bc!(residual, u, p, t)
44-
residual[1] = u(pi / 4)[1] + pi / 2
45-
residual[2] = u(pi / 2)[1] - pi / 2
46-
end
47-
u0 = [pi / 2, pi / 2]
48-
tspan = (0.0, pi / 2)
49-
prob = BVProblem(simplependulum!, bc!, u0, tspan)
50-
jac_alg_forwarddiff = BVPJacobianAlgorithm(
51-
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
52-
jac_alg_enzyme = BVPJacobianAlgorithm(
53-
bc_diffmode = AutoSparse(AutoEnzyme(
54-
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
55-
nonbc_diffmode = AutoEnzyme(
56-
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
57-
jac_alg_mooncake = BVPJacobianAlgorithm(
58-
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
59-
nonbc_diffmode = AutoEnzyme(
60-
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
61-
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
62-
@test_nowarn sol = solve(prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.05)
63-
end
36+
37+
@testset "Test different AD on multipoint BVP using Interpolation BC" begin
38+
function simplependulum!(du, u, p, t)
39+
θ = u[1]
40+
= u[2]
41+
du[1] =
42+
du[2] = -9.81 * sin(θ)
43+
end
44+
function bc!(residual, u, p, t)
45+
residual[1] = u(pi / 4)[1] + pi / 2
46+
residual[2] = u(pi / 2)[1] - pi / 2
6447
end
65-
=#
48+
u0 = [pi / 2, pi / 2]
49+
tspan = (0.0, pi / 2)
50+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
51+
jac_alg_forwarddiff = BVPJacobianAlgorithm(
52+
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
53+
jac_alg_enzyme = BVPJacobianAlgorithm(
54+
bc_diffmode = AutoSparse(AutoEnzyme(
55+
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
56+
nonbc_diffmode = AutoEnzyme(
57+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
58+
jac_alg_mooncake = BVPJacobianAlgorithm(
59+
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
60+
nonbc_diffmode = AutoEnzyme(
61+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
62+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
63+
sol = solve(
64+
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.05)
65+
@test SciMLBase.successful_retcode(sol)
66+
end
67+
end
68+
6669
@testset "Test different AD on twopoint BVP" begin
6770
function f!(du, u, p, t)
6871
du[1] = u[2]
@@ -87,8 +90,9 @@
8790
jac_alg_mooncake = BVPJacobianAlgorithm(AutoSparse(AutoMooncake(;
8891
config = nothing)))
8992
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
90-
@test_nowarn sol = solve(
93+
sol = solve(
9194
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.01)
95+
@test SciMLBase.successful_retcode(sol)
9296
end
9397
end
9498
end

lib/BoundaryValueDiffEqMIRK/test/ad_tests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
nonbc_diffmode = AutoEnzyme(
2929
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
3030
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
31-
@test_nowarn sol = solve(prob, MIRK4(; jac_alg = jac_alg), dt = 0.05)
31+
sol = solve(prob, MIRK4(; jac_alg = jac_alg), dt = 0.05)
32+
@test SciMLBase.successful_retcode(sol)
3233
end
3334
end
3435

@@ -58,7 +59,8 @@
5859
nonbc_diffmode = AutoEnzyme(
5960
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
6061
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
61-
@test_nowarn sol = solve(prob, MIRK4(; jac_alg = jac_alg), dt = 0.05)
62+
sol = solve(prob, MIRK4(; jac_alg = jac_alg), dt = 0.05)
63+
@test SciMLBase.successful_retcode(sol)
6264
end
6365
end
6466

@@ -86,7 +88,8 @@
8688
jac_alg_mooncake = BVPJacobianAlgorithm(AutoSparse(AutoMooncake(;
8789
config = nothing)))
8890
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
89-
@test_nowarn sol = solve(prob, MIRK4(; jac_alg = jac_alg), dt = 0.01)
91+
sol = solve(prob, MIRK4(; jac_alg = jac_alg), dt = 0.01)
92+
@test SciMLBase.successful_retcode(sol)
9093
end
9194
end
9295
end

0 commit comments

Comments
 (0)