diff --git a/src/callbacks.jl b/src/callbacks.jl index 40a2681f2..7a36e1d31 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -110,8 +110,13 @@ Contains a single callback whose `condition` is a continuous function. The callb `affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is used as that will lead to an unstable step following initialization. This warning can be ignored for non-DAE ODEs. + +# Extended help + +- `saved_clock_partitions`: An iterable of clock partition indices to save after the callback triggers. MTK-only + API """ -struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <: +struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP} <: AbstractContinuousCallback condition::F1 affect!::F2 @@ -127,20 +132,20 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <: reltol::T2 repeat_nudge::T3 initializealg::T4 + saved_clock_partitions::SCP function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3, initialize::F4, finalize::F5, idxs::I, rootfind, interp_points, save_positions, dtrelax::R, abstol::T, - reltol::T2, - repeat_nudge::T3, - initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R + reltol::T2, repeat_nudge::T3, initializealg::T4 = nothing, + saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP } _condition = prepare_function(condition) - new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition, + new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP}(_condition, affect!, affect_neg!, initialize, finalize, idxs, rootfind, interp_points, BitArray(collect(save_positions)), - dtrelax, abstol, reltol, repeat_nudge, initializealg) + dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions) end end @@ -154,12 +159,13 @@ function ContinuousCallback(condition, affect!, affect_neg!; dtrelax = 1, abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100, - initializealg = nothing) + initializealg = nothing, + saved_clock_partitions = ()) ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs, rootfind, interp_points, save_positions, - dtrelax, abstol, reltol, repeat_nudge, initializealg) + dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions) end function ContinuousCallback(condition, affect!; @@ -172,11 +178,11 @@ function ContinuousCallback(condition, affect!; interp_points = 10, dtrelax = 1, abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100, - initializealg = nothing) + initializealg = nothing, saved_clock_partitions = ()) ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs, rootfind, interp_points, collect(save_positions), - dtrelax, abstol, reltol, repeat_nudge, initializealg) + dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions) end """ @@ -219,8 +225,12 @@ multiple events. - `len`: Number of callbacks chained. This is compulsory to be specified. Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref). + +# Extended help + +- `saved_clock_partitions`: An iterable of `len` elements, where the `i`th element is an iterable of clock partition indices to save when the `i`th event triggers. MTK-only API. """ -struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <: +struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP} <: AbstractContinuousCallback condition::F1 affect!::F2 @@ -237,21 +247,24 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <: reltol::T2 repeat_nudge::T3 initializealg::T4 + saved_clock_partitions::SCP function VectorContinuousCallback( condition::F1, affect!::F2, affect_neg!::F3, len::Int, initialize::F4, finalize::F5, idxs::I, rootfind, interp_points, save_positions, dtrelax::R, - abstol::T, reltol::T2, - repeat_nudge::T3, - initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2, - T3, T4, I, R} + abstol::T, reltol::T2, repeat_nudge::T3, + initializealg::T4 = nothing, + saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, T, T2, + T3, T4, I, R, SCP} _condition = prepare_function(condition) - new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition, + new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP}( + _condition, affect!, affect_neg!, len, initialize, finalize, idxs, rootfind, interp_points, BitArray(collect(save_positions)), - dtrelax, abstol, reltol, repeat_nudge, initializealg) + dtrelax, abstol, reltol, repeat_nudge, initializealg, + saved_clock_partitions) end end @@ -264,13 +277,13 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len; interp_points = 10, dtrelax = 1, abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100, - initializealg = nothing) + initializealg = nothing, saved_clock_partitions = ()) VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize, idxs, rootfind, interp_points, save_positions, dtrelax, - abstol, reltol, repeat_nudge, initializealg) + abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions) end function VectorContinuousCallback(condition, affect!, len; @@ -283,12 +296,12 @@ function VectorContinuousCallback(condition, affect!, len; interp_points = 10, dtrelax = 1, abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100, - initializealg = nothing) + initializealg = nothing, saved_clock_partitions = ()) VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize, idxs, rootfind, interp_points, collect(save_positions), - dtrelax, abstol, reltol, repeat_nudge, initializealg) + dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions) end """ @@ -339,31 +352,39 @@ DiscreteCallback(condition, affect!; `affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is used as that will lead to an unstable step following initialization. This warning can be ignored for non-DAE ODEs. + +# Extended help + +- `saved_clock_partitions`: An iterable of clock partition indices to save after the callback + triggers. MTK-only API """ -struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback +struct DiscreteCallback{F1, F2, F3, F4, F5, SCP} <: AbstractDiscreteCallback condition::F1 affect!::F2 initialize::F3 finalize::F4 save_positions::BitArray{1} initializealg::F5 + saved_clock_partitions::SCP function DiscreteCallback(condition::F1, affect!::F2, initialize::F3, finalize::F4, save_positions, - initializealg::F5 = nothing) where {F1, F2, F3, F4, F5} + initializealg::F5 = nothing, + saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, SCP} _condition = prepare_function(condition) - new{typeof(_condition), F2, F3, F4, F5}(_condition, + new{typeof(_condition), F2, F3, F4, F5, SCP}(_condition, affect!, initialize, finalize, BitArray(collect(save_positions)), - initializealg) + initializealg, saved_clock_partitions) end end function DiscreteCallback(condition, affect!; initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT, save_positions = (true, true), - initializealg = nothing) + initializealg = nothing, saved_clock_partitions = ()) DiscreteCallback( - condition, affect!, initialize, finalize, save_positions, initializealg) + condition, affect!, initialize, finalize, save_positions, initializealg, + saved_clock_partitions) end """ @@ -420,3 +441,94 @@ end split_callbacks((cs..., d.continuous_callbacks...), (ds..., d.discrete_callbacks...), args...) end + +""" + $TYPEDSIGNATURES + +Save the discrete variables associated with callback `cb` in `integrator`. + +# Keyword arguments + +- `skip_duplicates`: Skip saving variables that have already been saved at the current time. +""" +function save_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, DiscreteCallback}; skip_duplicates = false) + isempty(cb.saved_clock_partitions) && return + for idx in cb.saved_clock_partitions + save_discretes!(integrator, idx; skip_duplicates) + end +end + +function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback; kw...) + isempty(cb.saved_clock_partitions) && return + for idx in eachindex(cb.saved_clock_partitions) + save_discretes!(integrator, cb, idx; skip_duplicates = true) + end +end + +function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback, i; skip_duplicates = false) + isempty(cb.saved_clock_partitions) && return + for idx in cb.saved_clock_partitions[i] + save_discretes!(integrator, idx; skip_duplicates) + end +end + +function _save_all_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...) + save_discretes!(integrator, cb; skip_duplicates = true) + _save_all_discretes!(integrator, cbs...) +end + +_save_all_discretes!(::DEIntegrator) = nothing + +function save_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...) + _save_all_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...) +end + +""" + $TYPEDSIGNATURES + +Save the discrete variables associated with callback `cb` in `integrator` if the finalizer +exists and `save_positions[2]` is `true`. Used to save the necessary values at the final +time of the simulation, after the finalizer has run. +""" +function save_final_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, VectorContinuousCallback, DiscreteCallback}) + cb.finalize === FINALIZE_DEFAULT && return + cb.save_positions[2] || return + save_discretes!(integrator, cb; skip_duplicates = true) +end + +function _save_all_final_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...) + save_final_discretes!(integrator, cb) + _save_all_final_discretes!(integrator, cbs...) +end + +_save_all_final_discretes!(::DEIntegrator) = nothing + +function save_final_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...) + _save_all_final_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...) +end + +""" + $TYPEDSIGNATURES + +Save the discrete variables associated with callback `cb` in `integrator` if +`save_positions[2]` is `true`. + +# Keyword arguments + +- `skip_duplicates`: Skip saving variables that have already been saved at the current time. +""" +function save_discretes_if_enabled!(integrator::DEIntegrator, cb::Union{ContinuousCallback, VectorContinuousCallback, DiscreteCallback}; skip_duplicates = false) + cb.save_positions[2] || return + save_discretes!(integrator, cb; skip_duplicates) +end + +function _save_discretes_if_enabled!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...; kw...) + save_discretes_if_enabled!(integrator, cb; kw...) + _save_discretes_if_enabled!(integrator, cbs...; kw...) +end + +_save_discretes_if_enabled!(::DEIntegrator; kw...) = nothing + +function save_discretes_if_enabled!(integrator::DEIntegrator, cb::CallbackSet; kw...) + _save_discretes_if_enabled!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...; kw...) +end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index c30e0a6d2..47d375627 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -422,28 +422,34 @@ end Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to get the values to save. If it returns `nothing`, then the save does not happen. """ -function save_discretes!(integ::DEIntegrator, timeseries_idx) +function save_discretes!(integ::DEIntegrator, timeseries_idx; skip_duplicates = false) inner_sol = get_sol(integ) vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx) vals === nothing && return - save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx) + save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx; skip_duplicates) end save_discretes!(args...) = nothing # public API, used by MTK -function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx) +function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx; skip_duplicates = false) RecursiveArrayTools.has_discretes(sol) || return disc = RecursiveArrayTools.get_discretes(sol) - _save_discretes_internal!(disc[timeseries_idx], t, vals) + _save_discretes_internal!(disc[timeseries_idx], t, vals; skip_duplicates) end -function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals) +function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals; skip_duplicates = false) + if skip_duplicates && !isempty(A.t) && isequal(t, A.t[end]) + return + end push!(A.t, t) push!(A.u, vals) end -function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals) +function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals; skip_duplicates = false) + if skip_duplicates && !isempty(A.u) && isequal(A.t[length(A.u)], t) + return + end idx = length(A.u) + 1 if A.t[idx] ≉ t error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")