Skip to content

Commit 424f103

Browse files
Fix Hermite fallback for ExplicitRK without B_interp
When B_interp is nothing (e.g. default ExplicitRK with Dormand-Prince), fall back to Hermite interpolation instead of throwing. Also skip the custom _ode_addsteps! in this case so k contains only the 2 entries Hermite expects. For the out-of-place interpolant, compute bi weights inline to support ForwardDiff.Dual Θ values. Co-Authored-By: Chris Rackauckas <[email protected]> Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 0e9a06d commit 424f103

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

lib/OrdinaryDiffEqExplicitRK/src/explicit_rk_perform_step.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -330,25 +330,25 @@ function generic_rk_interpolant(Θ, dt, y₀, k, B_interp, bi; idxs = nothing, o
330330

331331
inv_dt_factor = order <= 1 ? one(dt) : inv(dt)^(order - 1)
332332

333-
# Fill pre-allocated bi buffer with polynomial weights
334-
for i in 1:nstages
335-
bi[i] = eval_poly_derivative(Θ, @view(B_interp[i, :]), order)
336-
end
337-
333+
# Compute weights inline (Θ may be a ForwardDiff.Dual, so we cannot
334+
# store into the pre-allocated Float64 buffer here).
335+
b1 = eval_poly_derivative(Θ, @view(B_interp[1, :]), order)
338336
return if isnothing(idxs)
339-
interp_sum = k[1] * bi[1]
337+
interp_sum = k[1] * b1
340338
for i in 2:nstages
341-
interp_sum = interp_sum + k[i] * bi[i]
339+
bval = eval_poly_derivative(Θ, @view(B_interp[i, :]), order)
340+
interp_sum = interp_sum + k[i] * bval
342341
end
343342
if order == 0
344343
y₀ + dt * interp_sum
345344
else
346345
interp_sum * inv_dt_factor
347346
end
348347
else
349-
interp_sum = k[1][idxs] * bi[1]
348+
interp_sum = k[1][idxs] * b1
350349
for i in 2:nstages
351-
interp_sum = interp_sum + k[i][idxs] * bi[i]
350+
bval = eval_poly_derivative(Θ, @view(B_interp[i, :]), order)
351+
interp_sum = interp_sum + k[i][idxs] * bval
352352
end
353353
if order == 0
354354
y₀[idxs] + dt * interp_sum

lib/OrdinaryDiffEqExplicitRK/src/interpolants.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ const ExplicitRKCacheTypes = Union{ExplicitRKCache, ExplicitRKConstantCache}
1616
always_calc_begin = false, allow_calc_end = true,
1717
force_calc_end = false
1818
)
19+
# Only compute all stage vectors when B_interp is available for generic RK
20+
# interpolation. Without B_interp, the default Hermite addsteps suffices.
21+
isnothing(cache.B_interp) && return nothing
22+
1923
(; A, c, stages) = cache
2024

2125
if length(k) < stages || always_calc_begin
@@ -37,6 +41,10 @@ end
3741
always_calc_begin = false, allow_calc_end = true,
3842
force_calc_end = false
3943
)
44+
# Only compute all stage vectors when B_interp is available for generic RK
45+
# interpolation. Without B_interp, the default Hermite addsteps suffices.
46+
isnothing(cache.tab.B_interp) && return nothing
47+
4048
(; kk, tmp, tab) = cache
4149
(; A, c, stages) = tab
4250

@@ -94,6 +102,7 @@ end
94102
end
95103

96104
# Generate interpolant methods for derivative orders 0-3
105+
# Only dispatch when B_interp is available; otherwise fall through to Hermite default.
97106
for order in 0:3
98107
@eval begin
99108
@muladd function _ode_interpolant(
@@ -102,6 +111,17 @@ for order in 0:3
102111
idxs, T::Type{Val{$order}}, differential_vars
103112
)
104113
B_interp = get_B_interp(cache)
114+
if isnothing(B_interp)
115+
# Fall back to Hermite interpolation (the generic default)
116+
dv = OrdinaryDiffEqCore.interpolation_differential_vars(
117+
differential_vars, y₀, idxs
118+
)
119+
return OrdinaryDiffEqCore.hermite_interpolant(
120+
Θ, dt, y₀, y₁, k,
121+
Val{cache isa OrdinaryDiffEqMutableCache},
122+
idxs, T, dv
123+
)
124+
end
105125
bi = get_bi(cache)
106126
return generic_rk_interpolant(Θ, dt, y₀, k, B_interp, bi; idxs = idxs, order = $order)
107127
end
@@ -112,6 +132,16 @@ for order in 0:3
112132
idxs, T::Type{Val{$order}}, differential_vars
113133
)
114134
B_interp = get_B_interp(cache)
135+
if isnothing(B_interp)
136+
# Fall back to Hermite interpolation (the generic default)
137+
dv = OrdinaryDiffEqCore.interpolation_differential_vars(
138+
differential_vars, y₀, idxs
139+
)
140+
return OrdinaryDiffEqCore.hermite_interpolant!(
141+
out, Θ, dt, y₀, y₁, k,
142+
idxs, T, dv
143+
)
144+
end
115145
bi = get_bi(cache)
116146
return generic_rk_interpolant!(out, Θ, dt, y₀, k, B_interp, bi; idxs = idxs, order = $order)
117147
end

0 commit comments

Comments
 (0)