Skip to content

Commit e0cf88a

Browse files
committed
Fix FIRK adaptivity
1 parent 9c1018d commit e0cf88a

File tree

3 files changed

+282
-2
lines changed

3 files changed

+282
-2
lines changed

lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
After we construct an interpolant, we use interp_eval to evaluate it.
66
"""
77
@views function interp_eval!(
8-
y::AbstractArray, cache::FIRKCacheExpand{iip}, t, mesh, mesh_dt) where {iip}
8+
y::AbstractArray, cache::FIRKCacheExpand{iip, T, DiffCacheNeeded},
9+
t, mesh, mesh_dt) where {iip, T}
910
j = interval(mesh, t)
1011
h = mesh_dt[j]
1112
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1) # Cache length factor. We use a h corresponding to cache.y. Note that this assumes equidistributed mesh
@@ -48,7 +49,52 @@ After we construct an interpolant, we use interp_eval to evaluate it.
4849
end
4950

5051
@views function interp_eval!(
51-
y::AbstractArray, cache::FIRKCacheNested{iip, T}, t, mesh, mesh_dt) where {iip, T}
52+
y::AbstractArray, cache::FIRKCacheExpand{iip, T, NoDiffCacheNeeded},
53+
t, mesh, mesh_dt) where {iip, T}
54+
j = interval(mesh, t)
55+
h = mesh_dt[j]
56+
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1) # Cache length factor. We use a h corresponding to cache.y. Note that this assumes equidistributed mesh
57+
if lf > 1
58+
h *= lf
59+
end
60+
τ = (t - mesh[j])
61+
62+
(; f, M, stage, p, ITU) = cache
63+
(; q_coeff) = ITU
64+
65+
K = safe_similar(cache.y[1], M, stage)
66+
67+
ctr_y = (j - 1) * (stage + 1) + 1
68+
69+
yᵢ = cache.y[ctr_y]
70+
yᵢ₊₁ = cache.y[ctr_y + stage + 1]
71+
72+
if iip
73+
dyᵢ = similar(yᵢ)
74+
dyᵢ₊₁ = similar(yᵢ₊₁)
75+
76+
f(dyᵢ, yᵢ, p, mesh[j])
77+
f(dyᵢ₊₁, yᵢ₊₁, p, mesh[j + 1])
78+
else
79+
dyᵢ = f(yᵢ, p, mesh[j])
80+
dyᵢ₊₁ = f(yᵢ₊₁, p, mesh[j + 1])
81+
end
82+
83+
# Load interpolation residual
84+
for jj in 1:stage
85+
K[:, jj] = cache.y[ctr_y + jj]
86+
end
87+
88+
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
89+
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
90+
91+
S_interpolate!(y, τ, S_coeffs)
92+
return y
93+
end
94+
95+
@views function interp_eval!(
96+
y::AbstractArray, cache::FIRKCacheNested{iip, T, DiffCacheNeeded},
97+
t, mesh, mesh_dt) where {iip, T}
5298
(; f, ITU, nest_prob, alg) = cache
5399
(; q_coeff) = ITU
54100

@@ -92,6 +138,52 @@ end
92138
return y
93139
end
94140

141+
@views function interp_eval!(
142+
y::AbstractArray, cache::FIRKCacheNested{iip, T, NoDiffCacheNeeded},
143+
t, mesh, mesh_dt) where {iip, T}
144+
(; f, ITU, nest_prob, alg) = cache
145+
(; q_coeff) = ITU
146+
147+
j = interval(mesh, t)
148+
h = mesh_dt[j]
149+
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1) # Cache length factor. We use a h corresponding to cache.y. Note that this assumes equidistributed mesh
150+
if lf > 1
151+
h *= lf
152+
end
153+
τ = (t - mesh[j])
154+
155+
nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, alg.nlsolve)
156+
nestprob_p = zeros(T, cache.M + 2)
157+
158+
yᵢ = copy(cache.y[j])
159+
yᵢ₊₁ = copy(cache.y[j + 1])
160+
161+
if iip
162+
dyᵢ = similar(yᵢ)
163+
dyᵢ₊₁ = similar(yᵢ₊₁)
164+
165+
f(dyᵢ, yᵢ, cache.p, mesh[j])
166+
f(dyᵢ₊₁, yᵢ₊₁, cache.p, mesh[j + 1])
167+
else
168+
dyᵢ = f(yᵢ, cache.p, mesh[j])
169+
dyᵢ₊₁ = f(yᵢ₊₁, cache.p, mesh[j + 1])
170+
end
171+
172+
nestprob_p[1] = mesh[j]
173+
nestprob_p[2] = mesh_dt[j]
174+
nestprob_p[3:end] .= yᵢ
175+
176+
_nestprob = remake(nest_prob, p = nestprob_p)
177+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
178+
K = nestsol.u
179+
180+
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
181+
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
182+
183+
S_interpolate!(y, τ, S_coeffs)
184+
return y
185+
end
186+
95187
function get_S_coeffs(h, yᵢ, yᵢ₊₁, dyᵢ, dyᵢ₊₁, ymid, dymid)
96188
vals = vcat(yᵢ, yᵢ₊₁, dyᵢ, dyᵢ₊₁, ymid, dymid)
97189
M = length(yᵢ)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
@testitem "Different AD compatibility" begin
2+
using BoundaryValueDiffEqFIRK
3+
using ForwardDiff, Enzyme, Mooncake
4+
5+
@testset "Test different AD on multipoint BVP" begin
6+
function simplependulum!(du, u, p, t)
7+
θ = u[1]
8+
= u[2]
9+
du[1] =
10+
du[2] = -9.81 * sin(θ)
11+
end
12+
function bc!(residual, u, p, t)
13+
residual[1] = u[:, end ÷ 2][1] + pi / 2
14+
residual[2] = u[:, end][1] - pi / 2
15+
end
16+
u0 = [pi / 2, pi / 2]
17+
tspan = (0.0, pi / 2)
18+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
19+
jac_alg_forwarddiff = BVPJacobianAlgorithm(
20+
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
21+
jac_alg_enzyme = BVPJacobianAlgorithm(
22+
bc_diffmode = AutoSparse(AutoEnzyme(
23+
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
24+
nonbc_diffmode = AutoEnzyme(
25+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
26+
jac_alg_mooncake = BVPJacobianAlgorithm(
27+
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
28+
nonbc_diffmode = AutoEnzyme(
29+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
30+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
31+
@test_nowarn sol = solve(
32+
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = false), dt = 0.05)
33+
end
34+
end
35+
#=
36+
@testset "Test different AD on multipoint BVP using Interpolation BC" begin
37+
function simplependulum!(du, u, p, t)
38+
θ = u[1]
39+
dθ = u[2]
40+
du[1] = dθ
41+
du[2] = -9.81 * sin(θ)
42+
end
43+
function bc!(residual, u, p, t)
44+
residual[1] = u(pi / 4)[1] + pi / 2
45+
residual[2] = u(pi / 2)[1] - pi / 2
46+
end
47+
u0 = [pi / 2, pi / 2]
48+
tspan = (0.0, pi / 2)
49+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
50+
jac_alg_forwarddiff = BVPJacobianAlgorithm(
51+
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
52+
jac_alg_enzyme = BVPJacobianAlgorithm(
53+
bc_diffmode = AutoSparse(AutoEnzyme(
54+
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
55+
nonbc_diffmode = AutoEnzyme(
56+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
57+
jac_alg_mooncake = BVPJacobianAlgorithm(
58+
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
59+
nonbc_diffmode = AutoEnzyme(
60+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
61+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
62+
@test_nowarn sol = solve(prob, RadauIIa5(; jac_alg = jac_alg), dt = 0.05)
63+
end
64+
end
65+
=#
66+
@testset "Test different AD on twopoint BVP" begin
67+
function f!(du, u, p, t)
68+
du[1] = u[2]
69+
du[2] = 0
70+
end
71+
function boundary_two_point_a!(resida, ua, p)
72+
resida[1] = ua[1] - 5
73+
end
74+
function boundary_two_point_b!(residb, ub, p)
75+
residb[1] = ub[1]
76+
end
77+
78+
odef! = ODEFunction(f!, analytic = (u0, p, t) -> [5 - t, -1])
79+
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))
80+
tspan = (0.0, 5.0)
81+
u0 = [5.0, -3.5]
82+
prob = TwoPointBVProblem(odef!, (boundary_two_point_a!, boundary_two_point_b!),
83+
u0, tspan; bcresid_prototype, nlls = Val(false))
84+
jac_alg_forwarddiff = BVPJacobianAlgorithm(AutoSparse(AutoForwardDiff()))
85+
jac_alg_enzyme = BVPJacobianAlgorithm(AutoSparse(AutoEnzyme(
86+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated)))
87+
jac_alg_mooncake = BVPJacobianAlgorithm(AutoSparse(AutoMooncake(;
88+
config = nothing)))
89+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
90+
@test_nowarn sol = solve(
91+
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = false), dt = 0.01)
92+
end
93+
end
94+
end
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
@testitem "Different AD compatibility" begin
2+
using BoundaryValueDiffEqFIRK
3+
using ForwardDiff, Enzyme, Mooncake
4+
5+
@testset "Test different AD on multipoint BVP" begin
6+
function simplependulum!(du, u, p, t)
7+
θ = u[1]
8+
= u[2]
9+
du[1] =
10+
du[2] = -9.81 * sin(θ)
11+
end
12+
function bc!(residual, u, p, t)
13+
residual[1] = u[:, end ÷ 2][1] + pi / 2
14+
residual[2] = u[:, end][1] - pi / 2
15+
end
16+
u0 = [pi / 2, pi / 2]
17+
tspan = (0.0, pi / 2)
18+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
19+
jac_alg_forwarddiff = BVPJacobianAlgorithm(
20+
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
21+
jac_alg_enzyme = BVPJacobianAlgorithm(
22+
bc_diffmode = AutoSparse(AutoEnzyme(
23+
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
24+
nonbc_diffmode = AutoEnzyme(
25+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
26+
jac_alg_mooncake = BVPJacobianAlgorithm(
27+
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
28+
nonbc_diffmode = AutoEnzyme(
29+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
30+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
31+
@test_nowarn sol = solve(
32+
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.05)
33+
end
34+
end
35+
#=
36+
@testset "Test different AD on multipoint BVP using Interpolation BC" begin
37+
function simplependulum!(du, u, p, t)
38+
θ = u[1]
39+
dθ = u[2]
40+
du[1] = dθ
41+
du[2] = -9.81 * sin(θ)
42+
end
43+
function bc!(residual, u, p, t)
44+
residual[1] = u(pi / 4)[1] + pi / 2
45+
residual[2] = u(pi / 2)[1] - pi / 2
46+
end
47+
u0 = [pi / 2, pi / 2]
48+
tspan = (0.0, pi / 2)
49+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
50+
jac_alg_forwarddiff = BVPJacobianAlgorithm(
51+
bc_diffmode = AutoSparse(AutoForwardDiff()), nonbc_diffmode = AutoForwardDiff())
52+
jac_alg_enzyme = BVPJacobianAlgorithm(
53+
bc_diffmode = AutoSparse(AutoEnzyme(
54+
mode = Enzyme.Reverse, function_annotation = Enzyme.Duplicated)),
55+
nonbc_diffmode = AutoEnzyme(
56+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
57+
jac_alg_mooncake = BVPJacobianAlgorithm(
58+
bc_diffmode = AutoSparse(AutoMooncake(; config = nothing)),
59+
nonbc_diffmode = AutoEnzyme(
60+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated))
61+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
62+
@test_nowarn sol = solve(prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.05)
63+
end
64+
end
65+
=#
66+
@testset "Test different AD on twopoint BVP" begin
67+
function f!(du, u, p, t)
68+
du[1] = u[2]
69+
du[2] = 0
70+
end
71+
function boundary_two_point_a!(resida, ua, p)
72+
resida[1] = ua[1] - 5
73+
end
74+
function boundary_two_point_b!(residb, ub, p)
75+
residb[1] = ub[1]
76+
end
77+
78+
odef! = ODEFunction(f!, analytic = (u0, p, t) -> [5 - t, -1])
79+
bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1))
80+
tspan = (0.0, 5.0)
81+
u0 = [5.0, -3.5]
82+
prob = TwoPointBVProblem(odef!, (boundary_two_point_a!, boundary_two_point_b!),
83+
u0, tspan; bcresid_prototype, nlls = Val(false))
84+
jac_alg_forwarddiff = BVPJacobianAlgorithm(AutoSparse(AutoForwardDiff()))
85+
jac_alg_enzyme = BVPJacobianAlgorithm(AutoSparse(AutoEnzyme(
86+
mode = Enzyme.Forward, function_annotation = Enzyme.Duplicated)))
87+
jac_alg_mooncake = BVPJacobianAlgorithm(AutoSparse(AutoMooncake(;
88+
config = nothing)))
89+
for jac_alg in [jac_alg_forwarddiff, jac_alg_enzyme, jac_alg_mooncake]
90+
@test_nowarn sol = solve(
91+
prob, RadauIIa5(; jac_alg = jac_alg, nested_nlsolve = true), dt = 0.01)
92+
end
93+
end
94+
end

0 commit comments

Comments
 (0)