Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 283 additions & 3 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,276 @@ function (integrator::OperatorSplittingIntegrator)(tmp, t)
tmp, t, integrator.uprev, integrator.u, integrator.tprev, integrator.t)
end

"""
savevalues!(integrator::OperatorSplittingIntegrator, force_save=false)

Save the current solution values to the solution object. This is called by callbacks
when they trigger saves.
"""
function savevalues!(integrator::OperatorSplittingIntegrator, force_save = false)
saved = false
savedexactly = false

# Handle saveat times
tdir_t = tdir(integrator) * integrator.t
while !isempty(integrator.saveat) && first(integrator.saveat) <= tdir_t
saved = true
curt = tdir(integrator) * pop!(integrator.saveat)
if curt != integrator.t
# Interpolate to saveat time
val = copy(integrator.u)
integrator(val, curt)
push!(integrator.sol.t, curt)
push!(integrator.sol.u, val)
else
savedexactly = true
push!(integrator.sol.t, integrator.t)
push!(integrator.sol.u, copy(integrator.u))
end
end

# Force save if requested
if force_save && (isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t)
saved = true
savedexactly = true
push!(integrator.sol.t, integrator.t)
push!(integrator.sol.u, copy(integrator.u))
end

return saved, savedexactly
end

"""
handle_callbacks!(integrator::OperatorSplittingIntegrator)

Process callbacks after a step. This handles discrete callbacks at the outer
integrator level, checking conditions against the full solution state.

Note: Continuous callbacks are not fully supported for operator splitting methods
because root-finding across split subproblems is not well-defined. Users requiring
continuous callbacks should ensure their callback conditions are checked at
discrete time points.
"""
function handle_callbacks!(integrator::OperatorSplittingIntegrator)
discrete_callbacks = integrator.callback.discrete_callbacks
continuous_callbacks = integrator.callback.continuous_callbacks

discrete_modified = false
saved_in_cb = false

# Process discrete callbacks
if !(discrete_callbacks isa Tuple{})
discrete_modified, saved_in_cb = apply_discrete_callbacks!(integrator, discrete_callbacks)
end

# For continuous callbacks, we do a simplified check at the endpoint only
# Full root-finding is not supported for operator splitting
if !(continuous_callbacks isa Tuple{})
cont_modified, cont_saved = apply_continuous_callbacks_simple!(integrator, continuous_callbacks)
discrete_modified = discrete_modified || cont_modified
saved_in_cb = saved_in_cb || cont_saved
end

# Save if no callback saved
if !saved_in_cb
savevalues!(integrator)
end

integrator.u_modified = discrete_modified
end

"""
apply_discrete_callbacks!(integrator, callbacks)

Apply discrete callbacks to the integrator. Returns (modified, saved_in_cb) tuple.
"""
function apply_discrete_callbacks!(integrator::OperatorSplittingIntegrator, callbacks::Tuple)
modified = false
saved = false
for callback in callbacks
cb_modified, cb_saved = apply_discrete_callback!(integrator, callback)
modified = modified || cb_modified
saved = saved || cb_saved
end
return modified, saved
end

"""
apply_discrete_callback!(integrator, callback)

Apply a single discrete callback. Returns (modified, saved_in_cb) tuple.
"""
function apply_discrete_callback!(integrator::OperatorSplittingIntegrator, callback::SciMLBase.DiscreteCallback)
saved_in_cb = false
if callback.condition(integrator.u, integrator.t, integrator)
# Handle saveat
_, savedexactly = savevalues!(integrator)
saved_in_cb = true

if callback.save_positions[1]
# Save before affect if requested and not already saved
savedexactly || savevalues!(integrator, true)
end

integrator.u_modified = true
callback.affect!(integrator)

if integrator.u_modified
# Sync modified u to uprev so next step starts correctly
integrator.uprev .= integrator.u
end

if callback.save_positions[2]
savevalues!(integrator, true)
saved_in_cb = true
end

return integrator.u_modified, saved_in_cb
end
return false, saved_in_cb
end

"""
apply_continuous_callbacks_simple!(integrator, callbacks)

Apply continuous callbacks with a simplified endpoint-only check.
Full root-finding is not supported for operator splitting methods.
Returns (modified, saved_in_cb) tuple.
"""
function apply_continuous_callbacks_simple!(integrator::OperatorSplittingIntegrator, callbacks::Tuple)
modified = false
saved = false
for callback in callbacks
cb_modified, cb_saved = apply_continuous_callback_simple!(integrator, callback)
modified = modified || cb_modified
saved = saved || cb_saved
end
return modified, saved
end

"""
apply_continuous_callback_simple!(integrator, callback)

Apply a continuous callback with simplified endpoint checking.
This checks for sign changes between tprev and t without root-finding.
"""
function apply_continuous_callback_simple!(integrator::OperatorSplittingIntegrator, callback::SciMLBase.ContinuousCallback)
# Evaluate condition at previous and current time
if callback.idxs === nothing
prev_condition = callback.condition(integrator.uprev, integrator.tprev, integrator)
curr_condition = callback.condition(integrator.u, integrator.t, integrator)
else
prev_condition = callback.condition(
@view(integrator.uprev[callback.idxs]), integrator.tprev, integrator)
curr_condition = callback.condition(
@view(integrator.u[callback.idxs]), integrator.t, integrator)
end

prev_sign = sign(prev_condition)
curr_sign = sign(curr_condition)

# Check for sign change (zero crossing)
if prev_sign * curr_sign <= 0 && prev_sign != 0
saved_in_cb = false

# Handle saveat
_, savedexactly = savevalues!(integrator)
saved_in_cb = true

if callback.save_positions[1]
savedexactly || savevalues!(integrator, true)
end

integrator.u_modified = true

# Apply the appropriate affect function based on crossing direction
if prev_sign < 0 && callback.affect! !== nothing
callback.affect!(integrator)
elseif prev_sign > 0 && callback.affect_neg! !== nothing
callback.affect_neg!(integrator)
else
integrator.u_modified = false
end

if integrator.u_modified
# Sync modified u to uprev
integrator.uprev .= integrator.u
end

if callback.save_positions[2]
savevalues!(integrator, true)
saved_in_cb = true
end

return integrator.u_modified, saved_in_cb
end

return false, false
end

function apply_continuous_callback_simple!(integrator::OperatorSplittingIntegrator, callback::SciMLBase.VectorContinuousCallback)
# For VectorContinuousCallback, we need to check each component
prev_conditions = similar(integrator.u, callback.len)
curr_conditions = similar(integrator.u, callback.len)

if callback.idxs === nothing
callback.condition(prev_conditions, integrator.uprev, integrator.tprev, integrator)
callback.condition(curr_conditions, integrator.u, integrator.t, integrator)
else
callback.condition(prev_conditions,
@view(integrator.uprev[callback.idxs]), integrator.tprev, integrator)
callback.condition(curr_conditions,
@view(integrator.u[callback.idxs]), integrator.t, integrator)
end

# Find first event (sign change)
event_idx = 0
prev_sign_val = 0.0
for i in 1:callback.len
prev_sign = sign(prev_conditions[i])
curr_sign = sign(curr_conditions[i])
if prev_sign * curr_sign <= 0 && prev_sign != 0
event_idx = i
prev_sign_val = prev_sign
break
end
end

if event_idx > 0
saved_in_cb = false

_, savedexactly = savevalues!(integrator)
saved_in_cb = true

if callback.save_positions[1]
savedexactly || savevalues!(integrator, true)
end

integrator.u_modified = true

if prev_sign_val < 0 && callback.affect! !== nothing
callback.affect!(integrator, event_idx)
elseif prev_sign_val > 0 && callback.affect_neg! !== nothing
callback.affect_neg!(integrator, event_idx)
else
integrator.u_modified = false
end

if integrator.u_modified
integrator.uprev .= integrator.u
end

if callback.save_positions[2]
savevalues!(integrator, true)
saved_in_cb = true
end

return integrator.u_modified, saved_in_cb
end

return false, false
end

"""
stepsize_controller!(::OperatorSplittingIntegrator)

Expand Down Expand Up @@ -391,6 +661,12 @@ end
function __step!(integrator)
(; dtchangeable, tstops, _dt) = integrator

# If the previous step modified u (via callback), update uprev
if integrator.u_modified
integrator.uprev .= integrator.u
integrator.u_modified = false
end

# update dt before incrementing u; if dt is changeable and there is
# a tstop within dt, reduce dt to tstop - t
integrator.dt = !isempty(tstops) && dtchangeable ?
Expand All @@ -417,6 +693,9 @@ function __step!(integrator)

step_accept_controller!(integrator)

# Handle callbacks after the step
handle_callbacks!(integrator)

# remove tstops that were just reached
while !isempty(tstops) && reached_tstop(integrator, first(tstops))
pop!(tstops)
Expand Down Expand Up @@ -455,9 +734,10 @@ function DiffEqBase.add_saveat!(integrator::OperatorSplittingIntegrator, t)
push!(integrator.saveat, t)
end

# not sure what this should do?
# defined as default initialize: https://github.com/SciML/DiffEqBase.jl/blob/master/src/callbacks.jl#L3
DiffEqBase.u_modified!(i::OperatorSplittingIntegrator, bool) = nothing
# Set u_modified flag to track when callbacks modify the solution
function DiffEqBase.u_modified!(integrator::OperatorSplittingIntegrator, bool)
integrator.u_modified = bool
end

function synchronize_subintegrator_tree!(integrator::OperatorSplittingIntegrator)
synchronize_subintegrator!(integrator.subintegrator_tree, integrator)
Expand Down
Loading