Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 139 additions & 27 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,13 @@ Contains a single callback whose `condition` is a continuous function. The callb
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
used as that will lead to an unstable step following initialization. This warning can be
ignored for non-DAE ODEs.

# Extended help

- `saved_clock_partitions`: An iterable of clock partition indices to save after the callback triggers. MTK-only
API
"""
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP} <:
AbstractContinuousCallback
condition::F1
affect!::F2
Expand All @@ -127,20 +132,20 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
reltol::T2
repeat_nudge::T3
initializealg::T4
saved_clock_partitions::SCP
function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3,
initialize::F4, finalize::F5, idxs::I, rootfind,
interp_points, save_positions, dtrelax::R, abstol::T,
reltol::T2,
repeat_nudge::T3,
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R
reltol::T2, repeat_nudge::T3, initializealg::T4 = nothing,
saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP
}
_condition = prepare_function(condition)
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP}(_condition,
affect!, affect_neg!,
initialize, finalize, idxs, rootfind,
interp_points,
BitArray(collect(save_positions)),
dtrelax, abstol, reltol, repeat_nudge, initializealg)
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
end
end

Expand All @@ -154,12 +159,13 @@ function ContinuousCallback(condition, affect!, affect_neg!;
dtrelax = 1,
abstol = 10eps(), reltol = 0,
repeat_nudge = 1 // 100,
initializealg = nothing)
initializealg = nothing,
saved_clock_partitions = ())
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize,
idxs,
rootfind, interp_points,
save_positions,
dtrelax, abstol, reltol, repeat_nudge, initializealg)
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
end

function ContinuousCallback(condition, affect!;
Expand All @@ -172,11 +178,11 @@ function ContinuousCallback(condition, affect!;
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
initializealg = nothing, saved_clock_partitions = ())
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs,
rootfind, interp_points,
collect(save_positions),
dtrelax, abstol, reltol, repeat_nudge, initializealg)
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
end

"""
Expand Down Expand Up @@ -219,8 +225,12 @@ multiple events.
- `len`: Number of callbacks chained. This is compulsory to be specified.

Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).

# Extended help

- `saved_clock_partitions`: An iterable of `len` elements, where the `i`th element is an iterable of clock partition indices to save when the `i`th event triggers. MTK-only API.
"""
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP} <:
AbstractContinuousCallback
condition::F1
affect!::F2
Expand All @@ -237,21 +247,24 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
reltol::T2
repeat_nudge::T3
initializealg::T4
saved_clock_partitions::SCP
function VectorContinuousCallback(
condition::F1, affect!::F2, affect_neg!::F3, len::Int,
initialize::F4, finalize::F5, idxs::I, rootfind,
interp_points, save_positions, dtrelax::R,
abstol::T, reltol::T2,
repeat_nudge::T3,
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2,
T3, T4, I, R}
abstol::T, reltol::T2, repeat_nudge::T3,
initializealg::T4 = nothing,
saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, T, T2,
T3, T4, I, R, SCP}
_condition = prepare_function(condition)
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP}(
_condition,
affect!, affect_neg!, len,
initialize, finalize, idxs, rootfind,
interp_points,
BitArray(collect(save_positions)),
dtrelax, abstol, reltol, repeat_nudge, initializealg)
dtrelax, abstol, reltol, repeat_nudge, initializealg,
saved_clock_partitions)
end
end

Expand All @@ -264,13 +277,13 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len;
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
initializealg = nothing, saved_clock_partitions = ())
VectorContinuousCallback(condition, affect!, affect_neg!, len,
initialize, finalize,
idxs,
rootfind, interp_points,
save_positions, dtrelax,
abstol, reltol, repeat_nudge, initializealg)
abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
end

function VectorContinuousCallback(condition, affect!, len;
Expand All @@ -283,12 +296,12 @@ function VectorContinuousCallback(condition, affect!, len;
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
initializealg = nothing, saved_clock_partitions = ())
VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize,
idxs,
rootfind, interp_points,
collect(save_positions),
dtrelax, abstol, reltol, repeat_nudge, initializealg)
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
end

"""
Expand Down Expand Up @@ -339,31 +352,39 @@ DiscreteCallback(condition, affect!;
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
used as that will lead to an unstable step following initialization. This warning can be
ignored for non-DAE ODEs.

# Extended help

- `saved_clock_partitions`: An iterable of clock partition indices to save after the callback
triggers. MTK-only API
"""
struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback
struct DiscreteCallback{F1, F2, F3, F4, F5, SCP} <: AbstractDiscreteCallback
condition::F1
affect!::F2
initialize::F3
finalize::F4
save_positions::BitArray{1}
initializealg::F5
saved_clock_partitions::SCP
function DiscreteCallback(condition::F1, affect!::F2,
initialize::F3, finalize::F4,
save_positions,
initializealg::F5 = nothing) where {F1, F2, F3, F4, F5}
initializealg::F5 = nothing,
saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, SCP}
_condition = prepare_function(condition)
new{typeof(_condition), F2, F3, F4, F5}(_condition,
new{typeof(_condition), F2, F3, F4, F5, SCP}(_condition,
affect!, initialize, finalize,
BitArray(collect(save_positions)),
initializealg)
initializealg, saved_clock_partitions)
end
end
function DiscreteCallback(condition, affect!;
initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT,
save_positions = (true, true),
initializealg = nothing)
initializealg = nothing, saved_clock_partitions = ())
DiscreteCallback(
condition, affect!, initialize, finalize, save_positions, initializealg)
condition, affect!, initialize, finalize, save_positions, initializealg,
saved_clock_partitions)
end

"""
Expand Down Expand Up @@ -420,3 +441,94 @@ end
split_callbacks((cs..., d.continuous_callbacks...), (ds..., d.discrete_callbacks...),
args...)
end

"""
$TYPEDSIGNATURES

Save the discrete variables associated with callback `cb` in `integrator`.

# Keyword arguments

- `skip_duplicates`: Skip saving variables that have already been saved at the current time.
"""
function save_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, DiscreteCallback}; skip_duplicates = false)
isempty(cb.saved_clock_partitions) && return
for idx in cb.saved_clock_partitions
save_discretes!(integrator, idx; skip_duplicates)
end
end

function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback; kw...)
isempty(cb.saved_clock_partitions) && return
for idx in eachindex(cb.saved_clock_partitions)
save_discretes!(integrator, cb, idx; skip_duplicates = true)
end
end

function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback, i; skip_duplicates = false)
isempty(cb.saved_clock_partitions) && return
for idx in cb.saved_clock_partitions[i]
save_discretes!(integrator, idx; skip_duplicates)
end
end

function _save_all_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...)
save_discretes!(integrator, cb; skip_duplicates = true)
_save_all_discretes!(integrator, cbs...)
end

_save_all_discretes!(::DEIntegrator) = nothing

function save_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...)
_save_all_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...)
end

"""
$TYPEDSIGNATURES

Save the discrete variables associated with callback `cb` in `integrator` if the finalizer
exists and `save_positions[2]` is `true`. Used to save the necessary values at the final
time of the simulation, after the finalizer has run.
"""
function save_final_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, VectorContinuousCallback, DiscreteCallback})
cb.finalize === FINALIZE_DEFAULT && return
cb.save_positions[2] || return
save_discretes!(integrator, cb; skip_duplicates = true)
end

function _save_all_final_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...)
save_final_discretes!(integrator, cb)
_save_all_final_discretes!(integrator, cbs...)
end

_save_all_final_discretes!(::DEIntegrator) = nothing

function save_final_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...)
_save_all_final_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...)
end

"""
$TYPEDSIGNATURES

Save the discrete variables associated with callback `cb` in `integrator` if
`save_positions[2]` is `true`.

# Keyword arguments

- `skip_duplicates`: Skip saving variables that have already been saved at the current time.
"""
function save_discretes_if_enabled!(integrator::DEIntegrator, cb::Union{ContinuousCallback, VectorContinuousCallback, DiscreteCallback}; skip_duplicates = false)
cb.save_positions[2] || return
save_discretes!(integrator, cb; skip_duplicates)
end

function _save_discretes_if_enabled!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...; kw...)
save_discretes_if_enabled!(integrator, cb; kw...)
_save_discretes_if_enabled!(integrator, cbs...; kw...)
end

_save_discretes_if_enabled!(::DEIntegrator; kw...) = nothing

function save_discretes_if_enabled!(integrator::DEIntegrator, cb::CallbackSet; kw...)
_save_discretes_if_enabled!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...; kw...)
end
18 changes: 12 additions & 6 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,28 +422,34 @@ end
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
get the values to save. If it returns `nothing`, then the save does not happen.
"""
function save_discretes!(integ::DEIntegrator, timeseries_idx)
function save_discretes!(integ::DEIntegrator, timeseries_idx; skip_duplicates = false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this undocumented arg?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a docstring. It's to avoid saving variables twice.

inner_sol = get_sol(integ)
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
vals === nothing && return
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx; skip_duplicates)
end

save_discretes!(args...) = nothing

# public API, used by MTK
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx)
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx; skip_duplicates = false)
RecursiveArrayTools.has_discretes(sol) || return
disc = RecursiveArrayTools.get_discretes(sol)
_save_discretes_internal!(disc[timeseries_idx], t, vals)
_save_discretes_internal!(disc[timeseries_idx], t, vals; skip_duplicates)
end

function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals; skip_duplicates = false)
if skip_duplicates && !isempty(A.t) && isequal(t, A.t[end])
return
end
push!(A.t, t)
push!(A.u, vals)
end

function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals)
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals; skip_duplicates = false)
if skip_duplicates && !isempty(A.u) && isequal(A.t[length(A.u)], t)
return
end
idx = length(A.u) + 1
if A.t[idx] ≉ t
error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")
Expand Down
Loading