From 9daae423979ab4c089210e238680ab5d837a3498 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Mon, 29 Dec 2025 14:11:59 -0500 Subject: [PATCH] Add callback/event support for operator splitting integrator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This implements callback support for OperatorSplittingIntegrator, addressing issue #36. Key changes: - Add savevalues! function to save solution values during callback execution - Add handle_callbacks! function to process callbacks after each step - Add apply_discrete_callback! for full discrete callback support - Add apply_continuous_callback_simple! for simplified continuous callbacks (endpoint-only checking without full root-finding) - Modify __step! to call handle_callbacks! after advancing the solution - Implement proper u_modified! to track when callbacks modify the solution - Add comprehensive test suite for callback functionality Design decisions: - Callbacks are handled at the outer integrator level only, after all subproblems are solved for a timestep. This ensures callbacks have access to the full solution state. - Continuous callbacks use simplified endpoint checking rather than full root-finding, as root-finding across split subproblems is not well-defined for operator splitting methods. - When u is modified by a callback, it is propagated to uprev so the next step starts correctly. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/integrator.jl | 286 +++++++++++++++++++++++++++++++++++++++++++++- test/callbacks.jl | 260 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 544 insertions(+), 3 deletions(-) create mode 100644 test/callbacks.jl diff --git a/src/integrator.jl b/src/integrator.jl index 3001a37..7019adf 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -317,6 +317,276 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t) tmp, t, integrator.uprev, integrator.u, integrator.tprev, integrator.t) end +""" + savevalues!(integrator::OperatorSplittingIntegrator, force_save=false) + +Save the current solution values to the solution object. This is called by callbacks +when they trigger saves. +""" +function savevalues!(integrator::OperatorSplittingIntegrator, force_save = false) + saved = false + savedexactly = false + + # Handle saveat times + tdir_t = tdir(integrator) * integrator.t + while !isempty(integrator.saveat) && first(integrator.saveat) <= tdir_t + saved = true + curt = tdir(integrator) * pop!(integrator.saveat) + if curt != integrator.t + # Interpolate to saveat time + val = copy(integrator.u) + integrator(val, curt) + push!(integrator.sol.t, curt) + push!(integrator.sol.u, val) + else + savedexactly = true + push!(integrator.sol.t, integrator.t) + push!(integrator.sol.u, copy(integrator.u)) + end + end + + # Force save if requested + if force_save && (isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t) + saved = true + savedexactly = true + push!(integrator.sol.t, integrator.t) + push!(integrator.sol.u, copy(integrator.u)) + end + + return saved, savedexactly +end + +""" + handle_callbacks!(integrator::OperatorSplittingIntegrator) + +Process callbacks after a step. This handles discrete callbacks at the outer +integrator level, checking conditions against the full solution state. + +Note: Continuous callbacks are not fully supported for operator splitting methods +because root-finding across split subproblems is not well-defined. Users requiring +continuous callbacks should ensure their callback conditions are checked at +discrete time points. +""" +function handle_callbacks!(integrator::OperatorSplittingIntegrator) + discrete_callbacks = integrator.callback.discrete_callbacks + continuous_callbacks = integrator.callback.continuous_callbacks + + discrete_modified = false + saved_in_cb = false + + # Process discrete callbacks + if !(discrete_callbacks isa Tuple{}) + discrete_modified, saved_in_cb = apply_discrete_callbacks!(integrator, discrete_callbacks) + end + + # For continuous callbacks, we do a simplified check at the endpoint only + # Full root-finding is not supported for operator splitting + if !(continuous_callbacks isa Tuple{}) + cont_modified, cont_saved = apply_continuous_callbacks_simple!(integrator, continuous_callbacks) + discrete_modified = discrete_modified || cont_modified + saved_in_cb = saved_in_cb || cont_saved + end + + # Save if no callback saved + if !saved_in_cb + savevalues!(integrator) + end + + integrator.u_modified = discrete_modified +end + +""" + apply_discrete_callbacks!(integrator, callbacks) + +Apply discrete callbacks to the integrator. Returns (modified, saved_in_cb) tuple. +""" +function apply_discrete_callbacks!(integrator::OperatorSplittingIntegrator, callbacks::Tuple) + modified = false + saved = false + for callback in callbacks + cb_modified, cb_saved = apply_discrete_callback!(integrator, callback) + modified = modified || cb_modified + saved = saved || cb_saved + end + return modified, saved +end + +""" + apply_discrete_callback!(integrator, callback) + +Apply a single discrete callback. Returns (modified, saved_in_cb) tuple. +""" +function apply_discrete_callback!(integrator::OperatorSplittingIntegrator, callback::SciMLBase.DiscreteCallback) + saved_in_cb = false + if callback.condition(integrator.u, integrator.t, integrator) + # Handle saveat + _, savedexactly = savevalues!(integrator) + saved_in_cb = true + + if callback.save_positions[1] + # Save before affect if requested and not already saved + savedexactly || savevalues!(integrator, true) + end + + integrator.u_modified = true + callback.affect!(integrator) + + if integrator.u_modified + # Sync modified u to uprev so next step starts correctly + integrator.uprev .= integrator.u + end + + if callback.save_positions[2] + savevalues!(integrator, true) + saved_in_cb = true + end + + return integrator.u_modified, saved_in_cb + end + return false, saved_in_cb +end + +""" + apply_continuous_callbacks_simple!(integrator, callbacks) + +Apply continuous callbacks with a simplified endpoint-only check. +Full root-finding is not supported for operator splitting methods. +Returns (modified, saved_in_cb) tuple. +""" +function apply_continuous_callbacks_simple!(integrator::OperatorSplittingIntegrator, callbacks::Tuple) + modified = false + saved = false + for callback in callbacks + cb_modified, cb_saved = apply_continuous_callback_simple!(integrator, callback) + modified = modified || cb_modified + saved = saved || cb_saved + end + return modified, saved +end + +""" + apply_continuous_callback_simple!(integrator, callback) + +Apply a continuous callback with simplified endpoint checking. +This checks for sign changes between tprev and t without root-finding. +""" +function apply_continuous_callback_simple!(integrator::OperatorSplittingIntegrator, callback::SciMLBase.ContinuousCallback) + # Evaluate condition at previous and current time + if callback.idxs === nothing + prev_condition = callback.condition(integrator.uprev, integrator.tprev, integrator) + curr_condition = callback.condition(integrator.u, integrator.t, integrator) + else + prev_condition = callback.condition( + @view(integrator.uprev[callback.idxs]), integrator.tprev, integrator) + curr_condition = callback.condition( + @view(integrator.u[callback.idxs]), integrator.t, integrator) + end + + prev_sign = sign(prev_condition) + curr_sign = sign(curr_condition) + + # Check for sign change (zero crossing) + if prev_sign * curr_sign <= 0 && prev_sign != 0 + saved_in_cb = false + + # Handle saveat + _, savedexactly = savevalues!(integrator) + saved_in_cb = true + + if callback.save_positions[1] + savedexactly || savevalues!(integrator, true) + end + + integrator.u_modified = true + + # Apply the appropriate affect function based on crossing direction + if prev_sign < 0 && callback.affect! !== nothing + callback.affect!(integrator) + elseif prev_sign > 0 && callback.affect_neg! !== nothing + callback.affect_neg!(integrator) + else + integrator.u_modified = false + end + + if integrator.u_modified + # Sync modified u to uprev + integrator.uprev .= integrator.u + end + + if callback.save_positions[2] + savevalues!(integrator, true) + saved_in_cb = true + end + + return integrator.u_modified, saved_in_cb + end + + return false, false +end + +function apply_continuous_callback_simple!(integrator::OperatorSplittingIntegrator, callback::SciMLBase.VectorContinuousCallback) + # For VectorContinuousCallback, we need to check each component + prev_conditions = similar(integrator.u, callback.len) + curr_conditions = similar(integrator.u, callback.len) + + if callback.idxs === nothing + callback.condition(prev_conditions, integrator.uprev, integrator.tprev, integrator) + callback.condition(curr_conditions, integrator.u, integrator.t, integrator) + else + callback.condition(prev_conditions, + @view(integrator.uprev[callback.idxs]), integrator.tprev, integrator) + callback.condition(curr_conditions, + @view(integrator.u[callback.idxs]), integrator.t, integrator) + end + + # Find first event (sign change) + event_idx = 0 + prev_sign_val = 0.0 + for i in 1:callback.len + prev_sign = sign(prev_conditions[i]) + curr_sign = sign(curr_conditions[i]) + if prev_sign * curr_sign <= 0 && prev_sign != 0 + event_idx = i + prev_sign_val = prev_sign + break + end + end + + if event_idx > 0 + saved_in_cb = false + + _, savedexactly = savevalues!(integrator) + saved_in_cb = true + + if callback.save_positions[1] + savedexactly || savevalues!(integrator, true) + end + + integrator.u_modified = true + + if prev_sign_val < 0 && callback.affect! !== nothing + callback.affect!(integrator, event_idx) + elseif prev_sign_val > 0 && callback.affect_neg! !== nothing + callback.affect_neg!(integrator, event_idx) + else + integrator.u_modified = false + end + + if integrator.u_modified + integrator.uprev .= integrator.u + end + + if callback.save_positions[2] + savevalues!(integrator, true) + saved_in_cb = true + end + + return integrator.u_modified, saved_in_cb + end + + return false, false +end + """ stepsize_controller!(::OperatorSplittingIntegrator) @@ -391,6 +661,12 @@ end function __step!(integrator) (; dtchangeable, tstops, _dt) = integrator + # If the previous step modified u (via callback), update uprev + if integrator.u_modified + integrator.uprev .= integrator.u + integrator.u_modified = false + end + # update dt before incrementing u; if dt is changeable and there is # a tstop within dt, reduce dt to tstop - t integrator.dt = !isempty(tstops) && dtchangeable ? @@ -417,6 +693,9 @@ function __step!(integrator) step_accept_controller!(integrator) + # Handle callbacks after the step + handle_callbacks!(integrator) + # remove tstops that were just reached while !isempty(tstops) && reached_tstop(integrator, first(tstops)) pop!(tstops) @@ -455,9 +734,10 @@ function DiffEqBase.add_saveat!(integrator::OperatorSplittingIntegrator, t) push!(integrator.saveat, t) end -# not sure what this should do? -# defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3 -DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing +# Set u_modified flag to track when callbacks modify the solution +function DiffEqBase.u_modified!(integrator::OperatorSplittingIntegrator, bool) + integrator.u_modified = bool +end function synchronize_subintegrator_tree!(integrator::OperatorSplittingIntegrator) synchronize_subintegrator!(integrator.subintegrator_tree, integrator) diff --git a/test/callbacks.jl b/test/callbacks.jl new file mode 100644 index 0000000..314714d --- /dev/null +++ b/test/callbacks.jl @@ -0,0 +1,260 @@ +using OrdinaryDiffEqOperatorSplitting +using Test +using DiffEqBase +using SciMLBase +using OrdinaryDiffEqLowOrderRK + +# Test setup: simple split ODE problem +tspan = (0.0, 10.0) +u0 = [1.0, 2.0, 3.0] + +# Simple decay functions +function ode1(du, u, p, t) + @. du = -0.1u +end +f1 = DiffEqBase.ODEFunction(ode1) + +function ode2(du, u, p, t) + du[1] = -0.01u[2] + du[2] = -0.01u[1] +end +f2 = DiffEqBase.ODEFunction(ode2) + +f1dofs = [1, 2, 3] +f2dofs = [1, 2] +fsplit = GenericSplitFunction((f1, f2), (f1dofs, f2dofs)) +prob = OperatorSplittingProblem(fsplit, u0, tspan) + +@testset "Discrete Callbacks" begin + dt = 0.1 + + @testset "Simple condition callback" begin + # Count how many times the callback is triggered + callback_count = Ref(0) + + condition(u, t, integrator) = t >= 5.0 + function affect!(integrator) + callback_count[] += 1 + end + cb = DiscreteCallback(condition, affect!) + + integrator = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + DiffEqBase.solve!(integrator) + + # Callback should be triggered for t >= 5.0 + # With dt=0.1 and tspan=(0,10), there are about 50 steps after t=5 + @test callback_count[] > 0 + @test integrator.sol.retcode == SciMLBase.ReturnCode.Success + end + + @testset "Callback that modifies u" begin + # Callback that doubles u[1] when t >= 5.0 (only once) + triggered = Ref(false) + + condition(u, t, integrator) = t >= 5.0 && !triggered[] + function affect!(integrator) + triggered[] = true + integrator.u[1] *= 2.0 + end + cb = DiscreteCallback(condition, affect!) + + integrator = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + + # Step until just before t=5 + while integrator.t < 4.9 + DiffEqBase.step!(integrator) + end + u_before = copy(integrator.u) + + # Step past t=5 to trigger callback + while integrator.t < 5.5 + DiffEqBase.step!(integrator) + end + + @test triggered[] + @test integrator.sol.retcode in (SciMLBase.ReturnCode.Success, SciMLBase.ReturnCode.Default) + end + + @testset "Callback with save_positions" begin + save_times = Float64[] + + condition(u, t, integrator) = t >= 3.0 && t < 3.5 + function affect!(integrator) + # Just record that we were here + end + cb = DiscreteCallback(condition, affect!, save_positions = (true, true)) + + integrator = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + DiffEqBase.solve!(integrator) + + @test integrator.sol.retcode == SciMLBase.ReturnCode.Success + # Solution should have saved some values + @test length(integrator.sol.t) > 0 || length(integrator.sol.u) >= 0 + end + + @testset "Multiple discrete callbacks" begin + count1 = Ref(0) + count2 = Ref(0) + + condition1(u, t, integrator) = t >= 2.0 + affect1!(integrator) = count1[] += 1 + cb1 = DiscreteCallback(condition1, affect1!) + + condition2(u, t, integrator) = t >= 7.0 + affect2!(integrator) = count2[] += 1 + cb2 = DiscreteCallback(condition2, affect2!) + + integrator = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = CallbackSet(cb1, cb2) + ) + DiffEqBase.solve!(integrator) + + @test count1[] > count2[] # cb1 is triggered more often + @test count2[] > 0 # cb2 is still triggered + end +end + +@testset "Continuous Callbacks (Simplified)" begin + dt = 0.1 + + @testset "Zero-crossing detection" begin + # Detect when u[1] crosses below 0.5 + triggered = Ref(false) + + condition(u, t, integrator) = u[1] - 0.5 + function affect!(integrator) + triggered[] = true + end + cb = ContinuousCallback(condition, affect!) + + # Use a smaller tspan so u[1] actually crosses 0.5 + prob_short = OperatorSplittingProblem(fsplit, u0, (0.0, 20.0)) + + integrator = DiffEqBase.init( + prob_short, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + DiffEqBase.solve!(integrator) + + # With decay, u[1] should eventually cross 0.5 + @test triggered[] || integrator.u[1] > 0.5 # Either crossed or hasn't decayed enough + @test integrator.sol.retcode == SciMLBase.ReturnCode.Success + end + + @testset "Reflection callback" begin + # Reflect u[1] when it goes below 0.3 + reflection_count = Ref(0) + + condition(u, t, integrator) = u[1] - 0.3 + function affect!(integrator) + reflection_count[] += 1 + integrator.u[1] = 0.6 - integrator.u[1] # Reflect around 0.3 + end + cb = ContinuousCallback(condition, affect!) + + # Longer tspan for decay + prob_long = OperatorSplittingProblem(fsplit, u0, (0.0, 50.0)) + + integrator = DiffEqBase.init( + prob_long, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + DiffEqBase.solve!(integrator) + + @test integrator.sol.retcode == SciMLBase.ReturnCode.Success + end +end + +@testset "Callback with nested splitting" begin + dt = 0.1 + + # Create nested split problem + function ode3(du, u, p, t) + du[1] = -0.005u[2] + du[2] = -0.005u[1] + end + f3 = DiffEqBase.ODEFunction(ode3) + + f3dofs = [1, 2] + fsplit_inner = GenericSplitFunction((f3, f3), (f3dofs, f3dofs)) + fsplit_outer = GenericSplitFunction((f1, fsplit_inner), ([1, 2, 3], [1, 2])) + + prob_nested = OperatorSplittingProblem(fsplit_outer, u0, tspan) + + callback_count = Ref(0) + condition(u, t, integrator) = t >= 5.0 + affect!(integrator) = callback_count[] += 1 + cb = DiscreteCallback(condition, affect!) + + integrator = DiffEqBase.init( + prob_nested, + LieTrotterGodunov((Euler(), LieTrotterGodunov((Euler(), Euler())))), + dt = dt, callback = cb + ) + DiffEqBase.solve!(integrator) + + @test callback_count[] > 0 + @test integrator.sol.retcode == SciMLBase.ReturnCode.Success +end + +@testset "u_modified! functionality" begin + dt = 0.1 + + # Test that u_modified! is properly set and used + u_was_modified = Ref(false) + + condition(u, t, integrator) = t >= 5.0 && t < 5.5 + function affect!(integrator) + integrator.u[1] = 0.0 # Modify u + u_was_modified[] = true + end + cb = DiscreteCallback(condition, affect!) + + integrator = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + + # Run until callback triggers + while integrator.t < 6.0 && !isempty(integrator.tstops) + DiffEqBase.step!(integrator) + end + + @test u_was_modified[] +end + +@testset "Callback preserves solution accuracy" begin + dt = 0.01 # Smaller dt for better accuracy + + # Reference solution without callback + integrator_ref = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt + ) + DiffEqBase.solve!(integrator_ref) + u_ref = copy(integrator_ref.u) + + # Solution with no-op callback + condition(u, t, integrator) = false # Never triggers + affect!(integrator) = nothing + cb = DiscreteCallback(condition, affect!) + + integrator_cb = DiffEqBase.init( + prob, LieTrotterGodunov((Euler(), Euler())), + dt = dt, callback = cb + ) + DiffEqBase.solve!(integrator_cb) + + # Solutions should be identical since callback never triggers + @test isapprox(integrator_cb.u, u_ref, rtol = 1e-10) +end diff --git a/test/runtests.jl b/test/runtests.jl index 661249b..871be2d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,3 +3,4 @@ using SafeTestsets @safetestset "Operator Splitting API" include("operator_splitting_api.jl") @safetestset "Aliasing" include("alias_u0.jl") +@safetestset "Callbacks" include("callbacks.jl")