diff --git a/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl b/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl index 20537ea033..faa2d70204 100644 --- a/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl +++ b/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreEnzymeCoreExt.jl @@ -6,7 +6,7 @@ function EnzymeCore.EnzymeRules.inactive_noinl( true end function EnzymeCore.EnzymeRules.inactive_noinl( - ::typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), args...) + ::typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), args...) true end function EnzymeCore.EnzymeRules.inactive_noinl( diff --git a/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreMooncakeExt.jl b/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreMooncakeExt.jl index d328f9516d..dfdbcfc30c 100644 --- a/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreMooncakeExt.jl +++ b/lib/OrdinaryDiffEqCore/ext/OrdinaryDiffEqCoreMooncakeExt.jl @@ -8,7 +8,7 @@ Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{ Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{ typeof(OrdinaryDiffEqCore.SciMLBase.check_error), Any} Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{ - typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), Any, Any} + typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), Any, Any} Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{ typeof(OrdinaryDiffEqCore.final_progress), Any} diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl index c46f5d8c52..bc360b0ec6 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl @@ -76,20 +76,58 @@ function modify_dt_for_tstops!(integrator) if has_tstop(integrator) tdir_t = integrator.tdir * integrator.t tdir_tstop = first_tstop(integrator) + distance_to_tstop = abs(tdir_tstop - tdir_t) + + # Store the original dt to check if it gets significantly reduced + original_dt = abs(integrator.dt) + if integrator.opts.adaptive - integrator.dt = integrator.tdir * - min(abs(integrator.dt), abs(tdir_tstop - tdir_t)) # step! to the end + integrator.dtpropose = original_dt + if original_dt < distance_to_tstop + # Normal step, no tstop interference + integrator.next_step_tstop = false + else + # Distance is smaller, entering tstop snap mode + integrator.next_step_tstop = true + integrator.tstop_target = integrator.tdir * tdir_tstop + end + integrator.dt = integrator.tdir * min(original_dt, distance_to_tstop) elseif iszero(integrator.dtcache) && integrator.dtchangeable - integrator.dt = integrator.tdir * abs(tdir_tstop - tdir_t) + integrator.dt = integrator.tdir * distance_to_tstop + integrator.next_step_tstop = true + integrator.tstop_target = integrator.tdir * tdir_tstop elseif integrator.dtchangeable && !integrator.force_stepfail # always try to step! with dtcache, but lower if a tstop # however, if force_stepfail then don't set to dtcache, and no tstop worry - integrator.dt = integrator.tdir * - min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end + if abs(integrator.dtcache) < distance_to_tstop + # Normal step with dtcache, no tstop interference + integrator.next_step_tstop = false + else + # Distance is smaller, entering tstop snap mode + integrator.next_step_tstop = true + integrator.tstop_target = integrator.tdir * tdir_tstop + end + integrator.dt = integrator.tdir * min(abs(integrator.dtcache), distance_to_tstop) + else + integrator.next_step_tstop = false end + else + integrator.next_step_tstop = false end end +function handle_tstop_step!(integrator) + if integrator.t isa AbstractFloat && abs(integrator.dt) < eps(abs(integrator.t)) + # Skip perform_step! entirely for tiny dt + integrator.accept_step = true + else + # Normal step + perform_step!(integrator, integrator.cache) + end + + # Flag will be reset in fixed_t_for_tstop_error! when t is updated +end + # Want to extend savevalues! for DDEIntegrator function savevalues!(integrator::ODEIntegrator, force_save = false, reduce_size = true) _savevalues!(integrator, force_save, reduce_size) @@ -149,7 +187,7 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool} end if force_save || (integrator.opts.save_everystep && (isempty(integrator.sol.t) || - (integrator.t !== integrator.sol.t[end]) && + (integrator.t !== integrator.sol.t[end] || iszero(integrator.dt)) && (integrator.opts.save_end || integrator.t !== integrator.sol.prob.tspan[2]) )) integrator.saveiter += 1 @@ -274,12 +312,20 @@ function _loopfooter!(integrator) if integrator.accept_step # Accept increment_accept!(integrator.stats) integrator.last_stepfail = false + integrator.tprev = integrator.t + + if integrator.next_step_tstop + # Step controller dt is overly pessimistic, since dt = time to tstop + # For example, if super dense time, dt = eps(t), so the next step is tiny + # Thus if snap to tstop, let the step controller assume dt was the pre-fixed version + integrator.dt = integrator.dtpropose + end + integrator.t = fixed_t_for_tstop_error!(integrator, ttmp) + dtnew = DiffEqBase.value(step_accept_controller!(integrator, integrator.alg, q)) * oneunit(integrator.dt) - integrator.tprev = integrator.t - integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp) calc_dt_propose!(integrator, dtnew) handle_callbacks!(integrator) else # Reject @@ -288,7 +334,7 @@ function _loopfooter!(integrator) elseif !integrator.opts.adaptive #Not adaptive increment_accept!(integrator.stats) integrator.tprev = integrator.t - integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp) + integrator.t = fixed_t_for_tstop_error!(integrator, ttmp) integrator.last_stepfail = false integrator.accept_step = true integrator.dtpropose = integrator.dt @@ -327,16 +373,12 @@ function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, ts progress=(t-t1)/(t2-t1)) end -function fixed_t_for_floatingpoint_error!(integrator, ttmp) - if has_tstop(integrator) - tstop = integrator.tdir * first_tstop(integrator) - if abs(ttmp - tstop) < - 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * - oneunit(integrator.t) - tstop - else - ttmp - end +function fixed_t_for_tstop_error!(integrator, ttmp) + # If we're in tstop snap mode, use exact tstop target + if integrator.next_step_tstop + # Reset the flag now that we're snapping to tstop + integrator.next_step_tstop = false + return integrator.tstop_target else ttmp end @@ -468,10 +510,7 @@ function handle_tstop!(integrator) tdir_t = integrator.tdir * integrator.t tdir_tstop = first_tstop(integrator) if tdir_t == tdir_tstop - while tdir_t == tdir_tstop #remove all redundant copies - res = pop_tstop!(integrator) - has_tstop(integrator) ? (tdir_tstop = first_tstop(integrator)) : break - end + res = pop_tstop!(integrator) integrator.just_hit_tstop = true elseif tdir_t > tdir_tstop if !integrator.dtchangeable diff --git a/lib/OrdinaryDiffEqCore/src/integrators/type.jl b/lib/OrdinaryDiffEqCore/src/integrators/type.jl index 7dc9e1a9c6..0dd3358012 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/type.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/type.jl @@ -119,6 +119,8 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori force_stepfail::Bool last_stepfail::Bool just_hit_tstop::Bool + next_step_tstop::Bool + tstop_target::tType do_error_check::Bool event_last_time::Int vector_event_last_time::Int diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index 6ddd866f87..b2e0346ec5 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -503,8 +503,10 @@ function SciMLBase.__init( u_modified = false EEst = EEstT(1) just_hit_tstop = false + next_step_tstop = false + tstop_target = zero(t) isout = false - accept_step = false + accept_step = true force_stepfail = false last_stepfail = false do_error_check = true @@ -541,7 +543,7 @@ function SciMLBase.__init( callback_cache, kshortsize, force_stepfail, last_stepfail, - just_hit_tstop, do_error_check, + just_hit_tstop, next_step_tstop, tstop_target, do_error_check, event_last_time, vector_event_last_time, last_event_error, accept_step, @@ -603,14 +605,24 @@ end function SciMLBase.solve!(integrator::ODEIntegrator) @inbounds while !isempty(integrator.opts.tstops) - while integrator.tdir * integrator.t < first(integrator.opts.tstops) + first_tstop = first(integrator.opts.tstops) + while integrator.tdir * integrator.t <= first_tstop loopheader!(integrator) if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success return integrator.sol end - perform_step!(integrator, integrator.cache) + + # Use special tstop handling if flag is set, otherwise normal stepping + if integrator.next_step_tstop + handle_tstop_step!(integrator) + else + perform_step!(integrator, integrator.cache) + end + + should_exit = integrator.next_step_tstop + loopfooter!(integrator) - if isempty(integrator.opts.tstops) + if isempty(integrator.opts.tstops) || should_exit break end end @@ -662,11 +674,11 @@ end for t in tstops tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) + tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t) end for t in d_discontinuities tdir_t = tdir * t - tdir_t0 < tdir_t ≤ tdir_tf && push!(tstops_internal, tdir_t) + tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t) end push!(tstops_internal, tdir_tf) diff --git a/test/interface/ode_tstops_tests.jl b/test/interface/ode_tstops_tests.jl index a911ecfe8a..fa063ee99a 100644 --- a/test/interface/ode_tstops_tests.jl +++ b/test/interface/ode_tstops_tests.jl @@ -1,7 +1,6 @@ -using OrdinaryDiffEq, Test, Random +using OrdinaryDiffEq, Test, Random, StaticArrays, DiffEqCallbacks import ODEProblemLibrary: prob_ode_linear Random.seed!(100) - @testset "Tstops Tests on the Interval [0, 1]" begin prob = prob_ode_linear @@ -13,9 +12,8 @@ Random.seed!(100) sol = solve(prob, RK4(), dt = 1 // 3, tstops = [1 / 2], d_discontinuities = [-1 / 2, 1 / 2, 3 / 2], adaptive = false) - @test sol.t == [0, 1 / 3, 1 / 2, 1 / 3 + 1 / 2, 1] + @test sol.t == [0, 1 / 3, 1 / 2, 1 / 2, 1 / 3 + 1 / 2, 1] - # TODO integrator = init(prob, RK4(), tstops = [1 / 5, 1 / 4, 1 / 3, 1 / 2, 3 / 4], adaptive = false) @@ -88,3 +86,308 @@ end sol2 = solve(prob2, Tsit5()) @test 0.0:0.07:1.0 ⊆ sol2.t end + +# Tests for issue #2752: tstop overshoot errors with StaticArrays + +@testset "StaticArrays vs Arrays with extreme precision" begin + # Test the specific case that was failing: extreme precision + StaticArrays + function precise_dynamics(u, p, t) + x = @view u[1:2] + v = @view u[3:4] + + # Electromagnetic-like dynamics + dv = -0.01 * x + 1e-6 * sin(100*t) * SVector{2}(1, 1) + + return SVector{4}(v[1], v[2], dv[1], dv[2]) + end + + function precise_dynamics_array!(du, u, p, t) + x = @view u[1:2] + v = @view u[3:4] + + dv = -0.01 * x + 1e-6 * sin(100*t) * [1, 1] + du[1] = v[1] + du[2] = v[2] + du[3] = dv[1] + du[4] = dv[2] + end + + # Initial conditions + u0_static = SVector{4}(1.0, -0.5, 0.01, 0.01) + u0_array = [1.0, -0.5, 0.01, 0.01] + tspan = (0.0, 2.0) + tstops = [0.5, 1.0, 1.5] + + # Test with extreme tolerances that originally caused issues + prob_static = ODEProblem(precise_dynamics, u0_static, tspan) + sol_static = solve(prob_static, Vern9(); reltol=1e-12, abstol=1e-15, + tstops=tstops) + @test SciMLBase.successful_retcode(sol_static) + for tstop in tstops + @test tstop ∈ sol_static.t + end + + prob_array = ODEProblem(precise_dynamics_array!, u0_array, tspan) + sol_array = solve(prob_array, Vern9(); reltol=1e-12, abstol=1e-15, + tstops=tstops) + @test SciMLBase.successful_retcode(sol_static) + for tstop in tstops + @test tstop ∈ sol_array.t + end + + # Solutions should be very close despite different array types + @test isapprox(sol_static(2.0), sol_array(2.0), rtol=1e-10) +end + +@testset "Duplicate tstops handling" begin + function simple_ode(u, p, t) + SA[0.1 * u[1]] + end + + u0 = SVector{1}(1.0) + tspan = (0.0, 2.0) + + # Test multiple identical tstops - should all be processed + duplicate_tstops = [0.5, 0.5, 0.5, 1.0, 1.0] + + prob = ODEProblem(simple_ode, u0, tspan) + sol = solve(prob, Vern9(); tstops=duplicate_tstops) + + @test SciMLBase.successful_retcode(sol) + + # Count how many times each tstop appears in solution + count_05 = count(t -> abs(t - 0.5) < 1e-12, sol.t) + count_10 = count(t -> abs(t - 1.0) < 1e-12, sol.t) + + # Should handle all duplicate tstops (though may not save all due to deduplication) + @test count_05 >= 1 # At least one 0.5 + @test count_10 >= 1 # At least one 1.0 + + # Test with StaticArrays too + prob_static = ODEProblem(simple_ode, u0, tspan) + sol_static = solve(prob_static, Vern9(); tstops=duplicate_tstops) + @test SciMLBase.successful_retcode(sol_static) +end + +@testset "PresetTimeCallback with identical times" begin + # Test PresetTimeCallback scenarios where callbacks are set at same times as tstops + + event_times = Float64[] + callback_times = Float64[] + + function affect_preset!(integrator) + push!(callback_times, integrator.t) + integrator.u += 0.1* integrator.u # Small modification + end + + function simple_growth(u, p, t) + SA[0.1 * u[1]] + end + + u0 = SA[1.0] + tspan = (0.0, 3.0) + + # Define times where both tstops and callbacks should trigger + critical_times = [0.5, 1.0, 1.5, 2.0, 2.5] + + # Create PresetTimeCallback at the same times as tstops + preset_cb = PresetTimeCallback(critical_times, affect_preset!) + + prob = ODEProblem(simple_growth, u0, tspan) + sol = solve(prob, Vern9(); tstops=critical_times, callback=preset_cb, + reltol=1e-10, abstol=1e-12) + + @test SciMLBase.successful_retcode(sol) + + # Verify all tstops were hit + for time in critical_times + @test any(abs.(sol.t .- time) .< 1e-10) + end + + # Verify all callbacks were triggered + @test length(callback_times) == 2*length(critical_times) + for time in critical_times + @test any(abs.(callback_times .- time) .< 1e-10) + end + + # Test the same with regular arrays + u0_array = [1.0] + callback_times_array = Float64[] + + function affect_preset_array!(integrator) + push!(callback_times_array, integrator.t) + integrator.u[1] += 0.1 + end + + function simple_growth_array!(du, u, p, t) + du[1] = 0.1 * u[1] + end + + preset_cb_array = PresetTimeCallback(critical_times, affect_preset_array!) + + prob_array = ODEProblem(simple_growth_array!, u0_array, tspan) + sol_array = solve(prob_array, Vern9(); tstops=critical_times, callback=preset_cb_array, + reltol=1e-10, abstol=1e-12) + + @test SciMLBase.successful_retcode(sol_array) + @test length(callback_times_array) == 2*length(critical_times) + + # Both should have triggered all events + @test length(callback_times) == length(callback_times_array) == 2*length(critical_times) +end + +@testset "Super Dense Callback Times" begin + event_times = Float64[] + callback_times = Float64[] + + function condition(u,t,integrator) + t == 0.5 || t == 1.0 + end + + function affect_preset!(integrator) + push!(callback_times, integrator.t) + integrator.u[1] += 1.0 # Small modification + end + + function simple_growth(u, p, t) + [0.0] + end + + u0 = [1.0] + tspan = (0.0, 3.0) + + # Define times where both tstops and callbacks should trigger + critical_times = [0.5, 0.5, 0.5, 1.0, 1.0] + + # Create PresetTimeCallback at the same times as tstops + preset_cb = DiscreteCallback(condition, affect_preset!) + + prob = ODEProblem(simple_growth, u0, tspan) + + sol = solve(prob, Vern9(); tstops=critical_times, dt = 0.1, + reltol=1e-10, abstol=1e-12) + @test sol.t[3:5] == [0.5, 0.5, 0.5] + + # Tests that after super dense time, dt is not trivially small + @test sol.t[6:8] == [1.0, 1.0, 3.0] + + sol = solve(prob, Vern9(); tstops=critical_times, callback=preset_cb, + reltol=1e-10, abstol=1e-12, save_everystep=false) + + # Test that the callback is called at every repeat 0.5 and 1.0 + @test sol[end] == [6.0] +end + +@testset "Tiny tstop step handling" begin + # Test cases where tstop is very close to current time (dt < eps(t)) + function test_ode(u, p, t) + SA[0.01 * u[1]] + end + + u0 = SVector{1}(1.0) + tspan = (0.0, 1.0) + + # Create tstop very close to start time (would cause tiny dt) + tiny_tstops = [1e-15, 1e-14, 1e-13] + + for tiny_tstop in tiny_tstops + prob = ODEProblem(test_ode, u0, tspan) + sol = solve(prob, Vern9(); tstops=[tiny_tstop]) + @test SciMLBase.successful_retcode(sol) + @test any(abs.(sol.t .- tiny_tstop) .< 1e-14) # Should handle tiny tstop correctly + end + + prob = ODEProblem(test_ode, u0, tspan) + sol = solve(prob, Vern9(); tstops=tiny_tstops) + @test all(t ∈ sol.t for t in tiny_tstops) +end + +@testset "Multiple close tstops with StaticArrays" begin + # Test with multiple tstops that are very close together - stress test the flag logic + function oscillator(u, p, t) + SVector{2}(u[2], -u[1]) # Simple harmonic oscillator + end + + u0 = SVector{2}(1.0, 0.0) + tspan = (0.0, 4.0) + + # Multiple tstops close together (within floating-point precision range) + close_tstops = [1.0, 1.0 + 1e-14, 1.0 + 2e-14, 1.0 + 5e-14, + 2.0, 2.0 + 1e-15, 2.0 + 1e-14, + 3.0, 3.0 + 1e-13] + + prob = ODEProblem(oscillator, u0, tspan) + sol = solve(prob, Vern9(); tstops=close_tstops, reltol=1e-12, abstol=1e-15) + + @test SciMLBase.successful_retcode(sol) + + # Should handle all close tstops without error + # (Some might be deduplicated, but no errors should occur) + unique_times = [1.0, 2.0, 3.0] + for time in unique_times + @test any(abs.(sol.t .- time) .< 1e-10) # At least hit the main times + end +end + +@testset "Backward integration with tstop flags" begin + # Test that the fix works for backward time integration + function decay_ode(u, p, t) + SA[-0.1 * u[1]] + end + + u0 = SVector{1}(1.0) + tspan = (2.0, 0.0) # Backward integration + tstops = [1.5, 1.0, 0.5] + + prob = ODEProblem(decay_ode, u0, tspan) + sol = solve(prob, Vern9(); tstops=tstops, reltol=1e-12, abstol=1e-15) + + @test SciMLBase.successful_retcode(sol) + for tstop in tstops + @test tstop ∈ sol.t + end +end + +@testset "Continuous callbacks during tstop steps" begin + # Test that continuous callbacks work properly with tstop flag mechanism + + crossing_times = Float64[] + + function affect_continuous!(integrator) + push!(crossing_times, integrator.t) + end + + function condition_continuous(u, t, integrator) + u[1] - 0.5 # Crosses when u[1] = 0.5 + end + + function exponential_growth(u, p, t) + [0.2 * u[1]] # Exponential growth + end + + u0 = [0.1] # Start below 0.5 + tspan = (0.0, 10.0) + tstops = [2.0, 4.0, 6.0, 8.0] # Regular tstops + + continuous_cb = ContinuousCallback(condition_continuous, affect_continuous!) + + prob = ODEProblem(exponential_growth, u0, tspan) + sol = solve(prob, Vern9(); tstops=tstops, callback=continuous_cb, + reltol=1e-10, abstol=1e-12) + + @test SciMLBase.successful_retcode(sol) + + # Should hit all tstops + for tstop in tstops + @test tstop ∈ sol.t + end + + # Should also detect continuous callback crossings + @test length(crossing_times) > 0 # At least one crossing detected + + # Verify crossings are at correct value + for crossing_time in crossing_times + u_at_crossing = sol(crossing_time) + @test abs(u_at_crossing[1] - 0.5) < 1e-8 # Should be very close to 0.5 + end +end \ No newline at end of file