Skip to content

Commit 9541f1c

Browse files
authored
Merge branch 'master' into qqy/more_AD
2 parents c809a85 + 295799a commit 9541f1c

File tree

6 files changed

+301
-44
lines changed

6 files changed

+301
-44
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Random = "1.10"
5858
ReTestItems = "1.29"
5959
Reexport = "1.2"
6060
SciMLBase = "2.82"
61+
Sparspak = "0.3.11"
6162
StaticArrays = "1.9.8"
6263
Test = "1.10"
6364
julia = "1.10"
@@ -76,8 +77,9 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7677
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7778
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
7879
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
80+
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
7981
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
8082
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8183

8284
[targets]
83-
test = ["Aqua", "DiffEqDevTools", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "NonlinearSolveFirstOrder", "ODEInterface", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]
85+
test = ["Aqua", "DiffEqDevTools", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "NonlinearSolveFirstOrder", "ODEInterface", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "RecursiveArrayTools", "Sparspak", "StaticArrays", "Test"]

lib/BoundaryValueDiffEqFIRK/src/collocation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
K = get_tmp(k_discrete[i], u)
138138

139139
_nestprob = remake(nest_prob, p = nestprob_p)
140-
nestsol = solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
140+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
141141
@. K = nestsol.u
142142
@. residᵢ = yᵢ₊₁ - yᵢ
143143
__maybe_matmul!(residᵢ, nestsol.u, b, -h, T(1))
@@ -279,7 +279,7 @@ end
279279
nestprob_p[3:end] = yᵢ
280280

281281
_nestprob = remake(nest_prob, p = nestprob_p)
282-
nestsol = solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
282+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
283283

284284
@. residᵢ = yᵢ₊₁ - yᵢ
285285
__maybe_matmul!(residᵢ, nestsol.u, b, -h, T(1))

lib/BoundaryValueDiffEqFIRK/src/interpolation.jl

Lines changed: 178 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ end
5454

5555
for j in idx
5656
z = similar(cache.fᵢ₂_cache)
57-
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
57+
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
5858
vals[j] = idxs !== nothing ? z[idxs] : z
5959
end
6060
return DiffEqArray(vals, tvals)
@@ -68,18 +68,106 @@ end
6868

6969
for j in idx
7070
z = similar(cache.fᵢ₂_cache)
71-
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
71+
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
7272
vals[j] = z
7373
end
7474
end
7575

7676
@inline function interpolation(tval::Number, id::FIRKNestedInterpolation, idxs,
7777
deriv::D, p, continuity::Symbol = :left) where {D}
7878
z = similar(id.cache.fᵢ₂_cache)
79-
interp_eval!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt)
79+
interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
8080
return idxs !== nothing ? z[idxs] : z
8181
end
8282

83+
@inline function interpolant!(z::AbstractArray, cache::FIRKCacheNested{iip, T},
84+
t, mesh, mesh_dt, ::Type{Val{0}}) where {iip, T}
85+
(; f, ITU, nest_prob, alg) = cache
86+
(; q_coeff) = ITU
87+
88+
j = interval(mesh, t)
89+
h = mesh_dt[j]
90+
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1)
91+
if lf > 1
92+
h *= lf
93+
end
94+
τ = (t - mesh[j])
95+
96+
nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, alg.nlsolve)
97+
nestprob_p = zeros(T, cache.M + 2)
98+
99+
yᵢ = copy(cache.y[j].du)
100+
yᵢ₊₁ = copy(cache.y[j + 1].du)
101+
102+
if iip
103+
dyᵢ = similar(yᵢ)
104+
dyᵢ₊₁ = similar(yᵢ₊₁)
105+
106+
f(dyᵢ, yᵢ, cache.p, mesh[j])
107+
f(dyᵢ₊₁, yᵢ₊₁, cache.p, mesh[j + 1])
108+
else
109+
dyᵢ = f(yᵢ, cache.p, mesh[j])
110+
dyᵢ₊₁ = f(yᵢ₊₁, cache.p, mesh[j + 1])
111+
end
112+
113+
nestprob_p[1] = mesh[j]
114+
nestprob_p[2] = mesh_dt[j]
115+
nestprob_p[3:end] .= yᵢ
116+
117+
_nestprob = remake(nest_prob, p = nestprob_p)
118+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
119+
K = nestsol.u
120+
121+
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
122+
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
123+
124+
S_interpolate!(z, τ, S_coeffs)
125+
end
126+
127+
@inline function interpolant!(dz::AbstractArray, cache::FIRKCacheNested{iip, T},
128+
t, mesh, mesh_dt, ::Type{Val{1}}) where {iip, T}
129+
(; f, ITU, nest_prob, alg) = cache
130+
(; q_coeff) = ITU
131+
132+
j = interval(mesh, t)
133+
h = mesh_dt[j]
134+
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1)
135+
if lf > 1
136+
h *= lf
137+
end
138+
τ = (t - mesh[j])
139+
140+
nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, alg.nlsolve)
141+
nestprob_p = zeros(T, cache.M + 2)
142+
143+
yᵢ = copy(cache.y[j].du)
144+
yᵢ₊₁ = copy(cache.y[j + 1].du)
145+
146+
if iip
147+
dyᵢ = similar(yᵢ)
148+
dyᵢ₊₁ = similar(yᵢ₊₁)
149+
150+
f(dyᵢ, yᵢ, cache.p, mesh[j])
151+
f(dyᵢ₊₁, yᵢ₊₁, cache.p, mesh[j + 1])
152+
else
153+
dyᵢ = f(yᵢ, cache.p, mesh[j])
154+
dyᵢ₊₁ = f(yᵢ₊₁, cache.p, mesh[j + 1])
155+
end
156+
157+
nestprob_p[1] = mesh[j]
158+
nestprob_p[2] = mesh_dt[j]
159+
nestprob_p[3:end] .= yᵢ
160+
161+
_nestprob = remake(nest_prob, p = nestprob_p)
162+
nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...)
163+
K = nestsol.u
164+
165+
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K)
166+
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
167+
168+
dS_interpolate!(dz, τ, S_coeffs)
169+
end
170+
83171
## Expanded
84172
@inline function interpolation(tvals, id::FIRKExpandInterpolation, idxs,
85173
deriv::D, p, continuity::Symbol = :left) where {D}
@@ -97,7 +185,7 @@ end
97185

98186
for j in idx
99187
z = similar(cache.fᵢ₂_cache)
100-
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
188+
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
101189
vals[j] = idxs !== nothing ? z[idxs] : z
102190
end
103191
return DiffEqArray(vals, tvals)
@@ -111,18 +199,102 @@ end
111199

112200
for j in idx
113201
z = similar(cache.fᵢ₂_cache)
114-
interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt)
202+
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
115203
vals[j] = z
116204
end
117205
end
118206

119207
@inline function interpolation(tval::Number, id::FIRKExpandInterpolation, idxs,
120208
deriv::D, p, continuity::Symbol = :left) where {D}
121209
z = similar(id.cache.fᵢ₂_cache)
122-
interp_eval!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt)
210+
interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
123211
return idxs !== nothing ? z[idxs] : z
124212
end
125213

214+
@inline function interpolant!(z::AbstractArray, cache::FIRKCacheExpand{iip},
215+
t, mesh, mesh_dt, ::Type{Val{0}}) where {iip}
216+
j = interval(mesh, t)
217+
h = mesh_dt[j]
218+
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1)
219+
if lf > 1
220+
h *= lf
221+
end
222+
τ = (t - mesh[j])
223+
224+
(; f, M, stage, p, ITU) = cache
225+
(; q_coeff) = ITU
226+
227+
K = safe_similar(cache.y[1].du, M, stage)
228+
229+
ctr_y = (j - 1) * (stage + 1) + 1
230+
231+
yᵢ = cache.y[ctr_y].du
232+
yᵢ₊₁ = cache.y[ctr_y + stage + 1].du
233+
234+
if iip
235+
dyᵢ = similar(yᵢ)
236+
dyᵢ₊₁ = similar(yᵢ₊₁)
237+
238+
f(dyᵢ, yᵢ, p, mesh[j])
239+
f(dyᵢ₊₁, yᵢ₊₁, p, mesh[j + 1])
240+
else
241+
dyᵢ = f(yᵢ, p, mesh[j])
242+
dyᵢ₊₁ = f(yᵢ₊₁, p, mesh[j + 1])
243+
end
244+
245+
# Load interpolation residual
246+
for jj in 1:stage
247+
K[:, jj] = cache.y[ctr_y + jj].du
248+
end
249+
250+
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
251+
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
252+
253+
S_interpolate!(z, τ, S_coeffs)
254+
end
255+
256+
@inline function interpolant!(dz::AbstractArray, cache::FIRKCacheExpand{iip},
257+
t, mesh, mesh_dt, ::Type{Val{1}}) where {iip}
258+
j = interval(mesh, t)
259+
h = mesh_dt[j]
260+
lf = (length(cache.y₀) - 1) / (length(cache.y) - 1)
261+
if lf > 1
262+
h *= lf
263+
end
264+
τ = (t - mesh[j])
265+
266+
(; f, M, stage, p, ITU) = cache
267+
(; q_coeff) = ITU
268+
269+
K = safe_similar(cache.y[1].du, M, stage)
270+
271+
ctr_y = (j - 1) * (stage + 1) + 1
272+
273+
yᵢ = cache.y[ctr_y].du
274+
yᵢ₊₁ = cache.y[ctr_y + stage + 1].du
275+
276+
if iip
277+
dyᵢ = similar(yᵢ)
278+
dyᵢ₊₁ = similar(yᵢ₊₁)
279+
280+
f(dyᵢ, yᵢ, p, mesh[j])
281+
f(dyᵢ₊₁, yᵢ₊₁, p, mesh[j + 1])
282+
else
283+
dyᵢ = f(yᵢ, p, mesh[j])
284+
dyᵢ₊₁ = f(yᵢ₊₁, p, mesh[j + 1])
285+
end
286+
287+
# Load interpolation residual
288+
for jj in 1:stage
289+
K[:, jj] = cache.y[ctr_y + jj].du
290+
end
291+
292+
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
293+
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
294+
295+
dS_interpolate!(dz, τ, S_coeffs)
296+
end
297+
126298
@inline __build_interpolation(cache::FIRKCacheExpand, u::AbstractVector) = FIRKExpandInterpolation(
127299
cache.mesh, u, cache)
128300
@inline __build_interpolation(cache::FIRKCacheNested, u::AbstractVector) = FIRKNestedInterpolation(

lib/BoundaryValueDiffEqFIRK/test/expanded/firk_basic_tests.jl

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,12 @@ end
283283
(-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))]
284284
end
285285

286+
function prob_bvp_linear_analytic_derivative(u, λ, t)
287+
a = 1 / sqrt(λ)
288+
return [(-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a)),
289+
(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a))]
290+
end
291+
286292
function prob_bvp_linear_f!(du, u, p, t)
287293
du[1] = u[2]
288294
du[2] = 1 / p * u[1]
@@ -326,26 +332,51 @@ end
326332
end
327333

328334
@testset "Radau interpolations" begin
329-
@testset "RadauIIa$stage" for stage in (2, 3, 5, 7)
335+
@testset "Interpolation tests for RadauIIa$stage" for stage in (2, 3, 5, 7)
330336
@time sol = solve(prob_bvp_linear, radau_solver(Val(stage)); dt = 0.001)
331337
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
332338
@test sol(0.001; idxs = [1, 2])[0.998687464, -1.312035941] atol=testTol
333339
@test sol(0.001; idxs = 1)0.998687464 atol=testTol
334340
@test sol(0.001; idxs = 2)-1.312035941 atol=testTol
335341
end
342+
343+
@testset "Derivtive Interpolation tests for RadauIIa$stage" for stage in (
344+
2, 3, 5, 7)
345+
@time sol = solve(prob_bvp_linear, radau_solver(Val(stage)); dt = 0.001)
346+
sol_analytic = prob_bvp_linear_analytic(nothing, λ, 0.04)
347+
dsol_analytic = prob_bvp_linear_analytic_derivative(nothing, λ, 0.04)
348+
349+
@test sol(0.04, Val{0})sol_analytic atol=testTol
350+
@test sol(0.04, Val{1})dsol_analytic atol=testTol
351+
end
336352
end
337353

338354
@testset "LobattoIII interpolations" begin
339-
for (id, lobatto_solver) in zip(
340-
("a", "b", "c"), (lobattoIIIa_solver, lobattoIIIb_solver, lobattoIIIc_solver))
341-
begin
342-
@testset "LobattoIII$(id)$stage" for stage in (3, 4, 5)
343-
@time sol = solve(
344-
prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001)
345-
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
346-
@test sol(0.001; idxs = [1, 2])[0.998687464, -1.312035941] atol=testTol
347-
@test sol(0.001; idxs = 1)0.998687464 atol=testTol
348-
@test sol(0.001; idxs = 2)-1.312035941 atol=testTol
355+
@testset "Interpolation tests for Lobatto" begin
356+
for (id, lobatto_solver) in zip(("a", "b", "c"),
357+
(lobattoIIIa_solver, lobattoIIIb_solver, lobattoIIIc_solver))
358+
begin
359+
@testset "Interpolation tests for LobattoIII$(id)$stage" for stage in (
360+
3, 4, 5)
361+
@time sol = solve(
362+
prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001)
363+
@test sol(0.001)[0.998687464, -1.312035941] atol=testTol
364+
@test sol(0.001; idxs = [1, 2])[0.998687464, -1.312035941] atol=testTol
365+
@test sol(0.001; idxs = 1)0.998687464 atol=testTol
366+
@test sol(0.001; idxs = 2)-1.312035941 atol=testTol
367+
end
368+
369+
@testset "Derivative Interpolation tests for lobatto$(id)$stage" for stage in (
370+
3, 4, 5)
371+
@time sol = solve(
372+
prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001)
373+
sol_analytic = prob_bvp_linear_analytic(nothing, λ, 0.04)
374+
dsol_analytic = prob_bvp_linear_analytic_derivative(
375+
nothing, λ, 0.04)
376+
377+
@test sol(0.04, Val{0})sol_analytic atol=testTol
378+
@test sol(0.04, Val{1})dsol_analytic atol=testTol
379+
end
349380
end
350381
end
351382
end

0 commit comments

Comments
 (0)