Skip to content

Commit d314b77

Browse files
Merge pull request #1116 from AayushSabharwal/as/callback-save
feat: add discrete saving to callback structs
2 parents cc25931 + 2e4aaad commit d314b77

File tree

2 files changed

+151
-33
lines changed

2 files changed

+151
-33
lines changed

src/callbacks.jl

Lines changed: 139 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ Contains a single callback whose `condition` is a continuous function. The callb
110110
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
111111
used as that will lead to an unstable step following initialization. This warning can be
112112
ignored for non-DAE ODEs.
113+
114+
# Extended help
115+
116+
- `saved_clock_partitions`: An iterable of clock partition indices to save after the callback triggers. MTK-only
117+
API
113118
"""
114-
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
119+
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP} <:
115120
AbstractContinuousCallback
116121
condition::F1
117122
affect!::F2
@@ -127,20 +132,20 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
127132
reltol::T2
128133
repeat_nudge::T3
129134
initializealg::T4
135+
saved_clock_partitions::SCP
130136
function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3,
131137
initialize::F4, finalize::F5, idxs::I, rootfind,
132138
interp_points, save_positions, dtrelax::R, abstol::T,
133-
reltol::T2,
134-
repeat_nudge::T3,
135-
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R
139+
reltol::T2, repeat_nudge::T3, initializealg::T4 = nothing,
140+
saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP
136141
}
137142
_condition = prepare_function(condition)
138-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
143+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP}(_condition,
139144
affect!, affect_neg!,
140145
initialize, finalize, idxs, rootfind,
141146
interp_points,
142147
BitArray(collect(save_positions)),
143-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
148+
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
144149
end
145150
end
146151

@@ -154,12 +159,13 @@ function ContinuousCallback(condition, affect!, affect_neg!;
154159
dtrelax = 1,
155160
abstol = 10eps(), reltol = 0,
156161
repeat_nudge = 1 // 100,
157-
initializealg = nothing)
162+
initializealg = nothing,
163+
saved_clock_partitions = ())
158164
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize,
159165
idxs,
160166
rootfind, interp_points,
161167
save_positions,
162-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
168+
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
163169
end
164170

165171
function ContinuousCallback(condition, affect!;
@@ -172,11 +178,11 @@ function ContinuousCallback(condition, affect!;
172178
interp_points = 10,
173179
dtrelax = 1,
174180
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
175-
initializealg = nothing)
181+
initializealg = nothing, saved_clock_partitions = ())
176182
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs,
177183
rootfind, interp_points,
178184
collect(save_positions),
179-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
185+
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
180186
end
181187

182188
"""
@@ -219,8 +225,12 @@ multiple events.
219225
- `len`: Number of callbacks chained. This is compulsory to be specified.
220226
221227
Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).
228+
229+
# Extended help
230+
231+
- `saved_clock_partitions`: An iterable of `len` elements, where the `i`th element is an iterable of clock partition indices to save when the `i`th event triggers. MTK-only API.
222232
"""
223-
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
233+
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP} <:
224234
AbstractContinuousCallback
225235
condition::F1
226236
affect!::F2
@@ -237,21 +247,24 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
237247
reltol::T2
238248
repeat_nudge::T3
239249
initializealg::T4
250+
saved_clock_partitions::SCP
240251
function VectorContinuousCallback(
241252
condition::F1, affect!::F2, affect_neg!::F3, len::Int,
242253
initialize::F4, finalize::F5, idxs::I, rootfind,
243254
interp_points, save_positions, dtrelax::R,
244-
abstol::T, reltol::T2,
245-
repeat_nudge::T3,
246-
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2,
247-
T3, T4, I, R}
255+
abstol::T, reltol::T2, repeat_nudge::T3,
256+
initializealg::T4 = nothing,
257+
saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, T, T2,
258+
T3, T4, I, R, SCP}
248259
_condition = prepare_function(condition)
249-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
260+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, SCP}(
261+
_condition,
250262
affect!, affect_neg!, len,
251263
initialize, finalize, idxs, rootfind,
252264
interp_points,
253265
BitArray(collect(save_positions)),
254-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
266+
dtrelax, abstol, reltol, repeat_nudge, initializealg,
267+
saved_clock_partitions)
255268
end
256269
end
257270

@@ -264,13 +277,13 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len;
264277
interp_points = 10,
265278
dtrelax = 1,
266279
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
267-
initializealg = nothing)
280+
initializealg = nothing, saved_clock_partitions = ())
268281
VectorContinuousCallback(condition, affect!, affect_neg!, len,
269282
initialize, finalize,
270283
idxs,
271284
rootfind, interp_points,
272285
save_positions, dtrelax,
273-
abstol, reltol, repeat_nudge, initializealg)
286+
abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
274287
end
275288

276289
function VectorContinuousCallback(condition, affect!, len;
@@ -283,12 +296,12 @@ function VectorContinuousCallback(condition, affect!, len;
283296
interp_points = 10,
284297
dtrelax = 1,
285298
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
286-
initializealg = nothing)
299+
initializealg = nothing, saved_clock_partitions = ())
287300
VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize,
288301
idxs,
289302
rootfind, interp_points,
290303
collect(save_positions),
291-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
304+
dtrelax, abstol, reltol, repeat_nudge, initializealg, saved_clock_partitions)
292305
end
293306

294307
"""
@@ -339,31 +352,39 @@ DiscreteCallback(condition, affect!;
339352
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
340353
used as that will lead to an unstable step following initialization. This warning can be
341354
ignored for non-DAE ODEs.
355+
356+
# Extended help
357+
358+
- `saved_clock_partitions`: An iterable of clock partition indices to save after the callback
359+
triggers. MTK-only API
342360
"""
343-
struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback
361+
struct DiscreteCallback{F1, F2, F3, F4, F5, SCP} <: AbstractDiscreteCallback
344362
condition::F1
345363
affect!::F2
346364
initialize::F3
347365
finalize::F4
348366
save_positions::BitArray{1}
349367
initializealg::F5
368+
saved_clock_partitions::SCP
350369
function DiscreteCallback(condition::F1, affect!::F2,
351370
initialize::F3, finalize::F4,
352371
save_positions,
353-
initializealg::F5 = nothing) where {F1, F2, F3, F4, F5}
372+
initializealg::F5 = nothing,
373+
saved_clock_partitions::SCP = ()) where {F1, F2, F3, F4, F5, SCP}
354374
_condition = prepare_function(condition)
355-
new{typeof(_condition), F2, F3, F4, F5}(_condition,
375+
new{typeof(_condition), F2, F3, F4, F5, SCP}(_condition,
356376
affect!, initialize, finalize,
357377
BitArray(collect(save_positions)),
358-
initializealg)
378+
initializealg, saved_clock_partitions)
359379
end
360380
end
361381
function DiscreteCallback(condition, affect!;
362382
initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT,
363383
save_positions = (true, true),
364-
initializealg = nothing)
384+
initializealg = nothing, saved_clock_partitions = ())
365385
DiscreteCallback(
366-
condition, affect!, initialize, finalize, save_positions, initializealg)
386+
condition, affect!, initialize, finalize, save_positions, initializealg,
387+
saved_clock_partitions)
367388
end
368389

369390
"""
@@ -420,3 +441,94 @@ end
420441
split_callbacks((cs..., d.continuous_callbacks...), (ds..., d.discrete_callbacks...),
421442
args...)
422443
end
444+
445+
"""
446+
$TYPEDSIGNATURES
447+
448+
Save the discrete variables associated with callback `cb` in `integrator`.
449+
450+
# Keyword arguments
451+
452+
- `skip_duplicates`: Skip saving variables that have already been saved at the current time.
453+
"""
454+
function save_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, DiscreteCallback}; skip_duplicates = false)
455+
isempty(cb.saved_clock_partitions) && return
456+
for idx in cb.saved_clock_partitions
457+
save_discretes!(integrator, idx; skip_duplicates)
458+
end
459+
end
460+
461+
function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback; kw...)
462+
isempty(cb.saved_clock_partitions) && return
463+
for idx in eachindex(cb.saved_clock_partitions)
464+
save_discretes!(integrator, cb, idx; skip_duplicates = true)
465+
end
466+
end
467+
468+
function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback, i; skip_duplicates = false)
469+
isempty(cb.saved_clock_partitions) && return
470+
for idx in cb.saved_clock_partitions[i]
471+
save_discretes!(integrator, idx; skip_duplicates)
472+
end
473+
end
474+
475+
function _save_all_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...)
476+
save_discretes!(integrator, cb; skip_duplicates = true)
477+
_save_all_discretes!(integrator, cbs...)
478+
end
479+
480+
_save_all_discretes!(::DEIntegrator) = nothing
481+
482+
function save_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...)
483+
_save_all_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...)
484+
end
485+
486+
"""
487+
$TYPEDSIGNATURES
488+
489+
Save the discrete variables associated with callback `cb` in `integrator` if the finalizer
490+
exists and `save_positions[2]` is `true`. Used to save the necessary values at the final
491+
time of the simulation, after the finalizer has run.
492+
"""
493+
function save_final_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, VectorContinuousCallback, DiscreteCallback})
494+
cb.finalize === FINALIZE_DEFAULT && return
495+
cb.save_positions[2] || return
496+
save_discretes!(integrator, cb; skip_duplicates = true)
497+
end
498+
499+
function _save_all_final_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...)
500+
save_final_discretes!(integrator, cb)
501+
_save_all_final_discretes!(integrator, cbs...)
502+
end
503+
504+
_save_all_final_discretes!(::DEIntegrator) = nothing
505+
506+
function save_final_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...)
507+
_save_all_final_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...)
508+
end
509+
510+
"""
511+
$TYPEDSIGNATURES
512+
513+
Save the discrete variables associated with callback `cb` in `integrator` if
514+
`save_positions[2]` is `true`.
515+
516+
# Keyword arguments
517+
518+
- `skip_duplicates`: Skip saving variables that have already been saved at the current time.
519+
"""
520+
function save_discretes_if_enabled!(integrator::DEIntegrator, cb::Union{ContinuousCallback, VectorContinuousCallback, DiscreteCallback}; skip_duplicates = false)
521+
cb.save_positions[2] || return
522+
save_discretes!(integrator, cb; skip_duplicates)
523+
end
524+
525+
function _save_discretes_if_enabled!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...; kw...)
526+
save_discretes_if_enabled!(integrator, cb; kw...)
527+
_save_discretes_if_enabled!(integrator, cbs...; kw...)
528+
end
529+
530+
_save_discretes_if_enabled!(::DEIntegrator; kw...) = nothing
531+
532+
function save_discretes_if_enabled!(integrator::DEIntegrator, cb::CallbackSet; kw...)
533+
_save_discretes_if_enabled!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...; kw...)
534+
end

src/solutions/ode_solutions.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,28 +422,34 @@ end
422422
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
423423
get the values to save. If it returns `nothing`, then the save does not happen.
424424
"""
425-
function save_discretes!(integ::DEIntegrator, timeseries_idx)
425+
function save_discretes!(integ::DEIntegrator, timeseries_idx; skip_duplicates = false)
426426
inner_sol = get_sol(integ)
427427
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
428428
vals === nothing && return
429-
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
429+
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx; skip_duplicates)
430430
end
431431

432432
save_discretes!(args...) = nothing
433433

434434
# public API, used by MTK
435-
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx)
435+
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx; skip_duplicates = false)
436436
RecursiveArrayTools.has_discretes(sol) || return
437437
disc = RecursiveArrayTools.get_discretes(sol)
438-
_save_discretes_internal!(disc[timeseries_idx], t, vals)
438+
_save_discretes_internal!(disc[timeseries_idx], t, vals; skip_duplicates)
439439
end
440440

441-
function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
441+
function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals; skip_duplicates = false)
442+
if skip_duplicates && !isempty(A.t) && isequal(t, A.t[end])
443+
return
444+
end
442445
push!(A.t, t)
443446
push!(A.u, vals)
444447
end
445448

446-
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals)
449+
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals; skip_duplicates = false)
450+
if skip_duplicates && !isempty(A.u) && isequal(A.t[length(A.u)], t)
451+
return
452+
end
447453
idx = length(A.u) + 1
448454
if A.t[idx] t
449455
error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")

0 commit comments

Comments
 (0)