Skip to content

Commit ff0ac79

Browse files
feat: add discrete saving to callback structs
1 parent a6feb53 commit ff0ac79

File tree

1 file changed

+75
-27
lines changed

1 file changed

+75
-27
lines changed

src/callbacks.jl

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ Contains a single callback whose `condition` is a continuous function. The callb
110110
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
111111
used as that will lead to an unstable step following initialization. This warning can be
112112
ignored for non-DAE ODEs.
113+
114+
# Extended help
115+
116+
- `discrete_save_idxs`: An iterable of timeseries indexes to save after the callback triggers. MTK-only
117+
API
113118
"""
114-
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
119+
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI} <:
115120
AbstractContinuousCallback
116121
condition::F1
117122
affect!::F2
@@ -127,20 +132,20 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
127132
reltol::T2
128133
repeat_nudge::T3
129134
initializealg::T4
135+
discrete_save_idxs::DSI
130136
function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3,
131137
initialize::F4, finalize::F5, idxs::I, rootfind,
132138
interp_points, save_positions, dtrelax::R, abstol::T,
133-
reltol::T2,
134-
repeat_nudge::T3,
135-
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R
139+
reltol::T2, repeat_nudge::T3, initializealg::T4 = nothing,
140+
discrete_save_idxs::DSI = ()) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI
136141
}
137142
_condition = prepare_function(condition)
138-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
143+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI}(_condition,
139144
affect!, affect_neg!,
140145
initialize, finalize, idxs, rootfind,
141146
interp_points,
142147
BitArray(collect(save_positions)),
143-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
148+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
144149
end
145150
end
146151

@@ -154,12 +159,13 @@ function ContinuousCallback(condition, affect!, affect_neg!;
154159
dtrelax = 1,
155160
abstol = 10eps(), reltol = 0,
156161
repeat_nudge = 1 // 100,
157-
initializealg = nothing)
162+
initializealg = nothing,
163+
discrete_save_idxs = ())
158164
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize,
159165
idxs,
160166
rootfind, interp_points,
161167
save_positions,
162-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
168+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
163169
end
164170

165171
function ContinuousCallback(condition, affect!;
@@ -172,11 +178,11 @@ function ContinuousCallback(condition, affect!;
172178
interp_points = 10,
173179
dtrelax = 1,
174180
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
175-
initializealg = nothing)
181+
initializealg = nothing, discrete_save_idxs = ())
176182
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs,
177183
rootfind, interp_points,
178184
collect(save_positions),
179-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
185+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
180186
end
181187

182188
"""
@@ -219,8 +225,16 @@ multiple events.
219225
- `len`: Number of callbacks chained. This is compulsory to be specified.
220226
221227
Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).
228+
229+
# Extended help
230+
231+
- `discrete_save_idxs`: An iterable of `len` elements, where the `i`th element is an iterable of timeseries
232+
indexes to save when the `i`th event triggers. MTK-only API.
233+
- `all_save_idxs`: An iterable of all unique timeseries indexes in `discrete_save_idxs`. Used to save after
234+
`initialize` and `finalize`. This avoids saving twice if the same timeseries index is saved by two
235+
events.
222236
"""
223-
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
237+
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI, ASI} <:
224238
AbstractContinuousCallback
225239
condition::F1
226240
affect!::F2
@@ -237,21 +251,25 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
237251
reltol::T2
238252
repeat_nudge::T3
239253
initializealg::T4
254+
discrete_save_idxs::DSI
255+
all_save_idxs::ASI
240256
function VectorContinuousCallback(
241257
condition::F1, affect!::F2, affect_neg!::F3, len::Int,
242258
initialize::F4, finalize::F5, idxs::I, rootfind,
243259
interp_points, save_positions, dtrelax::R,
244-
abstol::T, reltol::T2,
245-
repeat_nudge::T3,
246-
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2,
247-
T3, T4, I, R}
260+
abstol::T, reltol::T2, repeat_nudge::T3,
261+
initializealg::T4 = nothing, discrete_save_idxs::DSI = (),
262+
all_save_idxs::ASI = ()) where {F1, F2, F3, F4, F5, T, T2,
263+
T3, T4, I, R, DSI, ASI}
248264
_condition = prepare_function(condition)
249-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
265+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI, ASI}(
266+
_condition,
250267
affect!, affect_neg!, len,
251268
initialize, finalize, idxs, rootfind,
252269
interp_points,
253270
BitArray(collect(save_positions)),
254-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
271+
dtrelax, abstol, reltol, repeat_nudge, initializealg,
272+
discrete_save_idxs, all_save_idxs)
255273
end
256274
end
257275

@@ -264,13 +282,15 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len;
264282
interp_points = 10,
265283
dtrelax = 1,
266284
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
267-
initializealg = nothing)
285+
initializealg = nothing, discrete_save_idxs = (),
286+
all_save_idxs = ())
268287
VectorContinuousCallback(condition, affect!, affect_neg!, len,
269288
initialize, finalize,
270289
idxs,
271290
rootfind, interp_points,
272291
save_positions, dtrelax,
273-
abstol, reltol, repeat_nudge, initializealg)
292+
abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs,
293+
all_save_idxs)
274294
end
275295

276296
function VectorContinuousCallback(condition, affect!, len;
@@ -283,12 +303,14 @@ function VectorContinuousCallback(condition, affect!, len;
283303
interp_points = 10,
284304
dtrelax = 1,
285305
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
286-
initializealg = nothing)
306+
initializealg = nothing, discrete_save_idxs = (),
307+
all_save_idxs = ())
287308
VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize,
288309
idxs,
289310
rootfind, interp_points,
290311
collect(save_positions),
291-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
312+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs,
313+
all_save_idxs)
292314
end
293315

294316
"""
@@ -339,31 +361,39 @@ DiscreteCallback(condition, affect!;
339361
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
340362
used as that will lead to an unstable step following initialization. This warning can be
341363
ignored for non-DAE ODEs.
364+
365+
# Extended help
366+
367+
- `discrete_save_idxs`: An iterable of timeseries indexes to save after the callback triggers. MTK-only
368+
API
342369
"""
343-
struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback
370+
struct DiscreteCallback{F1, F2, F3, F4, F5, DSI} <: AbstractDiscreteCallback
344371
condition::F1
345372
affect!::F2
346373
initialize::F3
347374
finalize::F4
348375
save_positions::BitArray{1}
349376
initializealg::F5
377+
discrete_save_idxs::DSI
350378
function DiscreteCallback(condition::F1, affect!::F2,
351379
initialize::F3, finalize::F4,
352380
save_positions,
353-
initializealg::F5 = nothing) where {F1, F2, F3, F4, F5}
381+
initializealg::F5 = nothing,
382+
discrete_save_idxs::DSI = ()) where {F1, F2, F3, F4, F5, DSI}
354383
_condition = prepare_function(condition)
355-
new{typeof(_condition), F2, F3, F4, F5}(_condition,
384+
new{typeof(_condition), F2, F3, F4, F5, DSI}(_condition,
356385
affect!, initialize, finalize,
357386
BitArray(collect(save_positions)),
358-
initializealg)
387+
initializealg, discrete_save_idxs)
359388
end
360389
end
361390
function DiscreteCallback(condition, affect!;
362391
initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT,
363392
save_positions = (true, true),
364-
initializealg = nothing)
393+
initializealg = nothing, discrete_save_idxs = ())
365394
DiscreteCallback(
366-
condition, affect!, initialize, finalize, save_positions, initializealg)
395+
condition, affect!, initialize, finalize, save_positions, initializealg,
396+
discrete_save_idxs)
367397
end
368398

369399
"""
@@ -420,3 +450,21 @@ end
420450
split_callbacks((cs..., d.continuous_callbacks...), (ds..., d.discrete_callbacks...),
421451
args...)
422452
end
453+
454+
function save_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, DiscreteCallback})
455+
for idx in cb.discrete_save_idxs
456+
save_discretes!(integrator, idx)
457+
end
458+
end
459+
460+
function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback)
461+
for idx in cb.all_save_idxs
462+
save_discretes!(integrator, idx)
463+
end
464+
end
465+
466+
function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback, i)
467+
for idx in cb.discrete_save_idxs[i]
468+
save_discretes!(integrator, idx)
469+
end
470+
end

0 commit comments

Comments
 (0)