Skip to content

Commit 4bf3de3

Browse files
Merge pull request #2869 from ChrisRackauckas-Claude/fix-tstop-overshoot-with-flag
Fix tstop overshoot error and super dense time event triggers
2 parents 10e8d33 + 3bcb852 commit 4bf3de3

File tree

6 files changed

+208
-30
lines changed

6 files changed

+208
-30
lines changed

lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
1111
typeof(OrdinaryDiffEqCore.increment_nf!), Vararg,
1212
}
1313
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
14-
typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), Vararg,
14+
typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), Vararg,
1515
}
1616
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
1717
typeof(OrdinaryDiffEqCore.increment_accept!), Vararg,

lib/OrdinaryDiffEqCore/src/enzyme_rules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function EnzymeCore.EnzymeRules.inactive_noinl(
55
end
66

77
function EnzymeCore.EnzymeRules.inactive_noinl(
8-
::typeof(fixed_t_for_floatingpoint_error!), args...
8+
::typeof(fixed_t_for_tstop_error!), args...
99
)
1010
return true
1111
end

lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,67 @@ function last_step_failed(integrator::ODEIntegrator)
9494
return integrator.last_stepfail && !integrator.opts.adaptive
9595
end
9696

97+
# Accessor functions for tstop flag fields with fallbacks for non-ODE integrators
98+
# (e.g. DDEIntegrator in DelayDiffEq.jl which doesn't have these fields)
99+
_get_next_step_tstop(integrator::ODEIntegrator) = integrator.next_step_tstop
100+
_get_next_step_tstop(integrator) = false
101+
102+
function _set_tstop_flag!(integrator::ODEIntegrator, is_tstop::Bool, target = nothing)
103+
integrator.next_step_tstop = is_tstop
104+
if is_tstop && target !== nothing
105+
integrator.tstop_target = target
106+
end
107+
return nothing
108+
end
109+
_set_tstop_flag!(integrator, is_tstop::Bool, target = nothing) = nothing
110+
111+
_get_tstop_target(integrator::ODEIntegrator) = integrator.tstop_target
112+
97113
function modify_dt_for_tstops!(integrator)
98114
return if has_tstop(integrator)
99115
tdir_t = integrator.tdir * integrator.t
100116
tdir_tstop = first_tstop(integrator)
117+
distance_to_tstop = abs(tdir_tstop - tdir_t)
118+
101119
if integrator.opts.adaptive
102-
integrator.dt = integrator.tdir *
103-
min(abs(integrator.dt), abs(tdir_tstop - tdir_t)) # step! to the end
120+
original_dt = abs(integrator.dt)
121+
integrator.dtpropose = original_dt
122+
if original_dt < distance_to_tstop
123+
_set_tstop_flag!(integrator, false)
124+
else
125+
_set_tstop_flag!(
126+
integrator, true, integrator.tdir * tdir_tstop)
127+
end
128+
integrator.dt = integrator.tdir * min(original_dt, distance_to_tstop)
104129
elseif iszero(integrator.dtcache) && integrator.dtchangeable
105-
integrator.dt = integrator.tdir * abs(tdir_tstop - tdir_t)
130+
integrator.dt = integrator.tdir * distance_to_tstop
131+
_set_tstop_flag!(
132+
integrator, true, integrator.tdir * tdir_tstop)
106133
elseif integrator.dtchangeable && !integrator.force_stepfail
107134
# always try to step! with dtcache, but lower if a tstop
108135
# however, if force_stepfail then don't set to dtcache, and no tstop worry
136+
if abs(integrator.dtcache) < distance_to_tstop
137+
_set_tstop_flag!(integrator, false)
138+
else
139+
_set_tstop_flag!(
140+
integrator, true, integrator.tdir * tdir_tstop)
141+
end
109142
integrator.dt = integrator.tdir *
110-
min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end
143+
min(abs(integrator.dtcache), distance_to_tstop)
144+
else
145+
_set_tstop_flag!(integrator, false)
111146
end
147+
else
148+
_set_tstop_flag!(integrator, false)
149+
end
150+
end
151+
152+
function handle_tstop_step!(integrator)
153+
return if integrator.t isa AbstractFloat && abs(integrator.dt) < eps(abs(integrator.t))
154+
# Skip perform_step! entirely for tiny dt
155+
integrator.accept_step = true
156+
else
157+
perform_step!(integrator, integrator.cache)
112158
end
113159
end
114160

@@ -183,7 +229,7 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool}
183229
integrator.opts.save_everystep &&
184230
(
185231
isempty(integrator.sol.t) ||
186-
(integrator.t !== integrator.sol.t[end]) &&
232+
(integrator.t !== integrator.sol.t[end] || iszero(integrator.dt)) &&
187233
(integrator.opts.save_end || integrator.t !== integrator.sol.prob.tspan[2])
188234
)
189235
)
@@ -344,6 +390,15 @@ function _loopfooter!(integrator)
344390
if integrator.accept_step # Accept
345391
increment_accept!(integrator.stats)
346392
integrator.last_stepfail = false
393+
integrator.tprev = integrator.t
394+
395+
if _get_next_step_tstop(integrator)
396+
# Step controller dt is overly pessimistic, since dt = time to tstop.
397+
# Restore the original dt so the controller proposes a reasonable next step.
398+
integrator.dt = integrator.dtpropose
399+
end
400+
integrator.t = fixed_t_for_tstop_error!(integrator, ttmp)
401+
347402
dtnew = DiffEqBase.value(
348403
step_accept_controller!(
349404
integrator,
@@ -352,8 +407,6 @@ function _loopfooter!(integrator)
352407
)
353408
) *
354409
oneunit(integrator.dt)
355-
integrator.tprev = integrator.t
356-
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
357410
calc_dt_propose!(integrator, dtnew)
358411
handle_callbacks!(integrator)
359412
else # Reject
@@ -362,7 +415,7 @@ function _loopfooter!(integrator)
362415
elseif !integrator.opts.adaptive #Not adaptive
363416
increment_accept!(integrator.stats)
364417
integrator.tprev = integrator.t
365-
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
418+
integrator.t = fixed_t_for_tstop_error!(integrator, ttmp)
366419
integrator.last_stepfail = false
367420
integrator.accept_step = true
368421
integrator.dtpropose = integrator.dt
@@ -406,18 +459,12 @@ function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, ts
406459
)
407460
end
408461

409-
function fixed_t_for_floatingpoint_error!(integrator, ttmp)
410-
return if has_tstop(integrator)
411-
tstop = integrator.tdir * first_tstop(integrator)
412-
if abs(ttmp - tstop) <
413-
100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) *
414-
oneunit(integrator.t)
415-
tstop
416-
else
417-
ttmp
418-
end
462+
function fixed_t_for_tstop_error!(integrator, ttmp)
463+
if _get_next_step_tstop(integrator)
464+
_set_tstop_flag!(integrator, false)
465+
return _get_tstop_target(integrator)
419466
else
420-
ttmp
467+
return ttmp
421468
end
422469
end
423470

lib/OrdinaryDiffEqCore/src/integrators/type.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ mutable struct ODEIntegrator{
128128
force_stepfail::Bool
129129
last_stepfail::Bool
130130
just_hit_tstop::Bool
131+
next_step_tstop::Bool
132+
tstop_target::tType
131133
do_error_check::Bool
132134
event_last_time::Int
133135
vector_event_last_time::Int

lib/OrdinaryDiffEqCore/src/solve.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,8 @@ function SciMLBase.__init(
629629
u_modified = false
630630
EEst = oneunit(EEstT) # https://github.com/JuliaPhysics/Measurements.jl/pull/135
631631
just_hit_tstop = false
632+
next_step_tstop = false
633+
tstop_target = zero(t)
632634
isout = false
633635
accept_step = false
634636
force_stepfail = false
@@ -678,7 +680,7 @@ function SciMLBase.__init(
678680
callback_cache,
679681
kshortsize, force_stepfail,
680682
last_stepfail,
681-
just_hit_tstop, do_error_check,
683+
just_hit_tstop, next_step_tstop, tstop_target, do_error_check,
682684
event_last_time,
683685
vector_event_last_time,
684686
last_event_error, accept_step,
@@ -741,14 +743,24 @@ end
741743

742744
function SciMLBase.solve!(integrator::ODEIntegrator)
743745
@inbounds while !isempty(integrator.opts.tstops)
744-
while integrator.tdir * integrator.t < first(integrator.opts.tstops)
746+
first_tstop = first(integrator.opts.tstops)
747+
while integrator.tdir * integrator.t < first_tstop
745748
loopheader!(integrator)
746749
if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success
747750
return integrator.sol
748751
end
749-
perform_step!(integrator, integrator.cache)
752+
753+
# Use special tstop handling if flag is set, otherwise normal stepping
754+
if integrator.next_step_tstop
755+
handle_tstop_step!(integrator)
756+
else
757+
perform_step!(integrator, integrator.cache)
758+
end
759+
760+
should_exit = integrator.next_step_tstop
761+
750762
loopfooter!(integrator)
751-
if isempty(integrator.opts.tstops)
763+
if isempty(integrator.opts.tstops) || should_exit
752764
break
753765
end
754766
end
@@ -821,11 +833,11 @@ end
821833

822834
for t in tstops
823835
tdir_t = tdir * t
824-
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
836+
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
825837
end
826838
for t in d_discontinuities
827839
tdir_t = tdir * t
828-
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
840+
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
829841
end
830842
push!(tstops_internal, tdir_tf)
831843

@@ -842,11 +854,11 @@ function reinit_tstops!(::Type{T}, tstops_internal, tstops, d_discontinuities, t
842854

843855
for t in tstops
844856
tdir_t = tdir * t
845-
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
857+
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
846858
end
847859
for t in d_discontinuities
848860
tdir_t = tdir * t
849-
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
861+
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
850862
end
851863
return push!(tstops_internal, tdir_tf)
852864
end

test/interface/ode_tstops_tests.jl

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, Test, Random
1+
using OrdinaryDiffEq, Test, Random, StaticArrays, DiffEqCallbacks
22
import ODEProblemLibrary: prob_ode_linear
33
Random.seed!(100)
44

@@ -100,3 +100,120 @@ end
100100
sol2 = solve(prob2, Tsit5())
101101
@test 0.0:0.07:1.0 sol2.t
102102
end
103+
104+
# Tests for issue #2752: tstop overshoot errors with StaticArrays
105+
106+
@testset "StaticArrays vs Arrays with extreme precision" begin
107+
function precise_dynamics(u, p, t)
108+
x = @view u[1:2]
109+
v = @view u[3:4]
110+
dv = -0.01 * x + 1.0e-6 * sin(100 * t) * SVector{2}(1, 1)
111+
return SVector{4}(v[1], v[2], dv[1], dv[2])
112+
end
113+
114+
function precise_dynamics_array!(du, u, p, t)
115+
x = @view u[1:2]
116+
v = @view u[3:4]
117+
dv = -0.01 * x + 1.0e-6 * sin(100 * t) * [1, 1]
118+
du[1] = v[1]
119+
du[2] = v[2]
120+
du[3] = dv[1]
121+
du[4] = dv[2]
122+
end
123+
124+
u0_static = SVector{4}(1.0, -0.5, 0.01, 0.01)
125+
u0_array = [1.0, -0.5, 0.01, 0.01]
126+
tspan = (0.0, 2.0)
127+
tstops = [0.5, 1.0, 1.5]
128+
129+
prob_static = ODEProblem(precise_dynamics, u0_static, tspan)
130+
sol_static = solve(
131+
prob_static, Vern9(); reltol = 1.0e-12, abstol = 1.0e-15,
132+
tstops = tstops
133+
)
134+
@test SciMLBase.successful_retcode(sol_static)
135+
for tstop in tstops
136+
@test tstop sol_static.t
137+
end
138+
139+
prob_array = ODEProblem(precise_dynamics_array!, u0_array, tspan)
140+
sol_array = solve(
141+
prob_array, Vern9(); reltol = 1.0e-12, abstol = 1.0e-15,
142+
tstops = tstops
143+
)
144+
@test SciMLBase.successful_retcode(sol_array)
145+
for tstop in tstops
146+
@test tstop sol_array.t
147+
end
148+
149+
@test isapprox(sol_static(2.0), sol_array(2.0), rtol = 1.0e-10)
150+
end
151+
152+
@testset "Backward integration with tstop flags" begin
153+
function decay_ode(u, p, t)
154+
SA[-0.1 * u[1]]
155+
end
156+
157+
u0 = SVector{1}(1.0)
158+
tspan = (2.0, 0.0)
159+
tstops = [1.5, 1.0, 0.5]
160+
161+
prob = ODEProblem(decay_ode, u0, tspan)
162+
sol = solve(prob, Vern9(); tstops = tstops, reltol = 1.0e-12, abstol = 1.0e-15)
163+
@test SciMLBase.successful_retcode(sol)
164+
for tstop in tstops
165+
@test tstop sol.t
166+
end
167+
end
168+
169+
@testset "PresetTimeCallback with tstop flags" begin
170+
callback_times = Float64[]
171+
172+
function affect_preset!(integrator)
173+
push!(callback_times, integrator.t)
174+
integrator.u += 0.1 * integrator.u
175+
end
176+
177+
function simple_growth(u, p, t)
178+
SA[0.1 * u[1]]
179+
end
180+
181+
u0 = SA[1.0]
182+
tspan = (0.0, 3.0)
183+
critical_times = [0.5, 1.0, 1.5, 2.0, 2.5]
184+
185+
preset_cb = PresetTimeCallback(critical_times, affect_preset!)
186+
187+
prob = ODEProblem(simple_growth, u0, tspan)
188+
sol = solve(
189+
prob, Vern9(); tstops = critical_times, callback = preset_cb,
190+
reltol = 1.0e-10, abstol = 1.0e-12
191+
)
192+
193+
@test SciMLBase.successful_retcode(sol)
194+
for time in critical_times
195+
@test any(abs.(sol.t .- time) .< 1.0e-10)
196+
end
197+
@test length(callback_times) == length(critical_times)
198+
end
199+
200+
@testset "Multiple close tstops with StaticArrays" begin
201+
function oscillator(u, p, t)
202+
SVector{2}(u[2], -u[1])
203+
end
204+
205+
u0 = SVector{2}(1.0, 0.0)
206+
tspan = (0.0, 4.0)
207+
close_tstops = [
208+
1.0, 1.0 + 1.0e-14, 1.0 + 2.0e-14, 1.0 + 5.0e-14,
209+
2.0, 2.0 + 1.0e-15, 2.0 + 1.0e-14,
210+
3.0, 3.0 + 1.0e-13,
211+
]
212+
213+
prob = ODEProblem(oscillator, u0, tspan)
214+
sol = solve(prob, Vern9(); tstops = close_tstops, reltol = 1.0e-12, abstol = 1.0e-15)
215+
@test SciMLBase.successful_retcode(sol)
216+
for time in [1.0, 2.0, 3.0]
217+
@test any(abs.(sol.t .- time) .< 1.0e-10)
218+
end
219+
end

0 commit comments

Comments
 (0)