Skip to content

Commit 0f12878

Browse files
committed
add docs and tests on Discrete/PresetTime CB
1 parent 01a6190 commit 0f12878

File tree

5 files changed

+65
-14
lines changed

5 files changed

+65
-14
lines changed

docs/src/API.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ set_bounds!
130130
NetworkDynamics.ComponentCallback
131131
ContinousComponentCallback
132132
VectorContinousComponentCallback
133+
DiscreteComponentCallback
134+
PresetTimeComponentCallback
133135
ComponentCondition
134136
ComponentAffect
135137
SymbolicView

docs/src/callbacks.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@ refer to the [Cascading Failure](@ref) example.
1616
In practice, events often act locally, meaning they only depend and act on a
1717
specific component or type of component. `NetworkDynamics.jl` provides a way of
1818
defining those callbacks on a component level and automaticially combine them into performant
19-
[`VectorContinuousCallback`](@extref SciMLBase.VectorContinuousCallback) for the whole network.
19+
[`VectorContinuousCallback`](@extref SciMLBase.VectorContinuousCallback) and [`DiscreteCallback`](@extref SciMLBase.DiscreteCallback) for the whole network.
2020

21-
The main entry points are the types [`ContinousComponentCallback`](@ref) and
22-
[`VectorContinousComponentCallback`](@ref). Both of those objects combine a [`ComponentCondition`](@ref)
23-
with an [`ComponentAffect`](@ref).
24-
The "normal" `ContinousComponentCallback` has a condition which returns a single value. The corresponding affect is triggered when the return value hits zero.
21+
The main entry points are the types [`ContinousComponentCallback`](@ref),
22+
[`VectorContinousComponentCallback`](@ref) and [`DiscreteComponentCallback`](@ref). All of those objects combine a [`ComponentCondition`](@ref) with an [`ComponentAffect`](@ref).
23+
24+
The "normal" [`ContinousComponentCallback`](@ref) and [`DiscreteComponentCallback`](@ref) have a condition which returns a single value. The corresponding affect is triggered when the return value hits zero.
2525
In contrast, the "vector" version has an in-place condition which writes `len` outputs. When any of those outputs hits zero, the affect is triggered with an additional argument `event_idx` which tells the effect which dimension encountered the zerocrossing.
2626

27+
There is a special type [`PresetTimeComponentCallback`](@ref) which has no explicit condition and triggers the affect at given times.
28+
This internally generates a [`PresetTimeCallback`](@ref DiffEqCallbacks.PresetTimeCallback) object from `DiffEqCallbacks.jl`.
29+
30+
2731
### Defining the Callback
2832
To construct a condition function, you need to tell network dynamics which states and parameters you'd like to "observe" within the condition. Within the actual condition, those states will be made available:
2933
```julia

src/NetworkDynamics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using ForwardDiff: ForwardDiff
2121
using Printf: @sprintf
2222
using Random: Random
2323
using Static: Static, StaticInt
24-
using SciMLBase: VectorContinuousCallback, CallbackSet
24+
using SciMLBase: VectorContinuousCallback, CallbackSet, DiscreteCallback
2525
using DiffEqCallbacks: DiffEqCallbacks
2626

2727
@static if VERSION v"1.11.0-0"

src/callbacks.jl

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ abstract type ComponentCallback end
1919
ComponentCondition(f::Function, sym, psym)
2020
2121
Creates a callback condition for a [`ComponentCallback`].
22-
- `f`: The condition function. Must be a function of the form `out=f(u, p, t)` when used
23-
for [`ContinousComponentCallback`](@ref) or `f!(out, u, p, t)` when used for
22+
- `f`: The condition function. Must be a function of the form `out=f(u, p, t)`
23+
when used for [`ContinousComponentCallback`](@ref) or
24+
[`DiscreteComponentcallback`](@ref) and `f!(out, u, p, t)` when used for
2425
[`VectorContinousComponentCallback`](@ref).
2526
- Arguments of `f`
2627
- `u`: The current value of the selecte `sym` states, provided as a [`SymbolicView`](@ref) object.
@@ -130,7 +131,7 @@ The `affect` will be triggered with the additional `event_idx` argument to know
130131
dimension the zerocrossing was detected.
131132
132133
The `kwargs` will be forwarded to the `VectorContinuousCallback` when the component based
133-
callbacks are collected for the whole network using `get_callbacks`.
134+
callbacks are collected for the whole network using [`get_callbacks(::Network)`](@ref).
134135
[`DiffEq.jl docs`](https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/)
135136
for available options.
136137
"""
@@ -147,6 +148,21 @@ function VectorContinousComponentCallback(condition, affect, len; kwargs...)
147148
VectorContinousComponentCallback(condition, affect, len, NamedTuple(kwargs))
148149
end
149150

151+
"""
152+
DiscreteComponentCallback(condition, affect; kwargs...)
153+
154+
Connect a [`ComponentCondition`](@ref) and a [`ComponentAffect`)[@ref] to a
155+
discrete callback which can be attached to a component model using
156+
[`add_callback!`](@ref) or [`set_callback!`](@ref).
157+
158+
Note that the `condition` function returns a boolean value, as the discrete
159+
callback perform no rootfinding.
160+
161+
The `kwargs` will be forwarded to the `DiscreteCallback` when the component based
162+
callbacks are collected for the whole network using [`get_callbacks(::Network)`](@ref).
163+
[`DiffEq.jl docs`](https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/)
164+
for available options.
165+
"""
150166
struct DiscreteComponentCallback{C<:ComponentCondition,A<:ComponentAffect} <: ComponentCallback
151167
condition::C
152168
affect::A
@@ -156,6 +172,20 @@ function DiscreteComponentCallback(condition, affect; kwargs...)
156172
DiscreteComponentCallback(condition, affect, NamedTuple(kwargs))
157173
end
158174

175+
"""
176+
PresetTimeComponentCallback(ts, affect; kwargs...)
177+
178+
Tirgger a [`ComponentAffect`](@ref) at given timesteps `ts` in discrete
179+
callback, which can be attached to a component model using
180+
[`add_callback!`](@ref) or [`set_callback!`](@ref).
181+
182+
The `kwargs` will be forwarded to the [`PresetTimeCallback`](@ref DiffEqCallbacks.PresetTimeCallback)
183+
when the component based callbacks are collected for the whole network using
184+
[`get_callbacks(::Network)`](@ref).
185+
186+
The `PresetTimeCallback` will take care of adding the timesteps to the solver, ensuring to
187+
exactly trigger at the correct times.
188+
"""
159189
struct PresetTimeComponentCallback{T,A} <: ComponentCallback
160190
ts::T
161191
affect::A
@@ -417,12 +447,12 @@ end
417447
function _batch_condition(dcw::DiscreteCallbackWrapper)
418448
uidxtype = dcw.component isa EIndex ? EIndex : VIndex
419449
pidxtype = dcw.component isa EIndex ? EPIndex : VPIndex
420-
usymidxs = uidxtype(dcw.component.compidx, dcw.callback.cond.sym)
421-
psymidxs = pidxtype(dcw.component.compidx, dcw.callback.cond.psym)
450+
usymidxs = uidxtype(dcw.component.compidx, dcw.callback.condition.sym)
451+
psymidxs = pidxtype(dcw.component.compidx, dcw.callback.condition.psym)
422452
ucache = DiffCache(zeros(length(usymidxs)), 12)
423453

424-
obsf = SII.observed(ccw.nw, usymidxs)
425-
pidxs = SII.parameter_index.(Ref(ccw.nw), psymidxs)
454+
obsf = SII.observed(dcw.nw, usymidxs)
455+
pidxs = SII.parameter_index.(Ref(dcw.nw), psymidxs)
426456

427457
(u, t, integrator) -> begin
428458
us = PreallocationTools.get_tmp(ucache, u)

test/callbacks_test.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,29 @@ end
7676
b = @b $batchcond($out, $u, NaN, $integrator)
7777
@test b.allocs == 0
7878

79+
# test the preste time callback
7980
tripfirst = PresetTimeComponentCallback(1.0, affect) # reuse the same affect
8081
add_callback!(nw[EIndex(5)], tripfirst)
8182

82-
nwcb = NetworkDynamics.get_callbacks(nw)
83+
# add a useless discrete callback
84+
useless_triggertime = Ref{Float64}(0.0)
85+
usless_cond = ComponentCondition([:P, :₋P, :srcθ], [:limit, :K]) do u, p, t
86+
t > 0.1 && iszero(useless_triggertime[])
87+
end
88+
usless_affect = ComponentAffect([], [:limit, :K]) do u, p, ctx
89+
@info "Usless effect triggered at $(ctx.t)"
90+
useless_triggertime[] = ctx.t
91+
end
92+
useless_cb = DiscreteComponentCallback(usless_cond, usless_affect)
93+
add_callback!(nw[EIndex(1)], useless_cb)
94+
95+
nwcb = NetworkDynamics.get_callbacks(nw);
8396
s0 = NWState(nw)
8497
prob = ODEProblem(nw, uflat(s0), (0,6), copy(pflat(s0)), callback=nwcb)
8598
sol = solve(prob, Tsit5());
8699

100+
@test 0.1 < useless_triggertime[] <= 1.0
101+
87102
@test tripi == [5,7,4,1,3,2]
88103
tref = [1, 2.247676397005474, 2.502523192233235, 3.1947647115093654, 3.3380530127462587, 3.4042696241577888]
89104
@test maximum(abs.(tript - tref)) < 1e-5

0 commit comments

Comments
 (0)