Skip to content

Commit 21652a5

Browse files
committed
add docs, fix callback tests
1 parent aa3e801 commit 21652a5

File tree

4 files changed

+277
-50
lines changed

4 files changed

+277
-50
lines changed

src/NetworkDynamics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ include("metadata.jl")
7777

7878
export ComponentCondition, ComponentAffect
7979
export ContinousComponentCallback, VectorContinousComponentCallback
80+
export SymbolicView, get_callbacks
8081
include("callbacks.jl")
8182

8283
using NonlinearSolve: AbstractNonlinearSolveAlgorithm, NonlinearFunction

src/callbacks.jl

Lines changed: 187 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,50 @@
1+
"""
2+
abstract type ComponentCallback{C,A} end
3+
4+
Abstract type for a component based callback. A component callback
5+
bundles a [`ComponentCondition`](@ref) as well as a [`ComponentAffect`](@ref)
6+
which can be then tied to a component model using [`add_callback!`](@ref) or
7+
[`set_callback!`](@ref).
8+
9+
On a Network level, you can automaticially create network wide `CallbackSet`s using
10+
[`get_callbacks`](@ref).
11+
12+
See
13+
[`ContinousComponentCallback`](@ref) and [`VectorContinousComponentCallback`](@ref) for concrete
14+
implemenations of this abstract type.
15+
"""
116
abstract type ComponentCallback{C,A} end
217

18+
"""
19+
ComponentCondition(f::Function, sym, psym)
20+
21+
Creates a callback condition for a [`ComponentCallback`].
22+
- `f`: The condition function. Must be a function of the form `out=f(u, p, t)` when used
23+
for [`ContinouseComponentcallback`](@ref) or `f!(out, u, p, t)` when used for
24+
[`VectorContinousComponentCallback`](@ref).
25+
- Arguments of `f`
26+
- `u`: The current value of the selecte `sym` states, provided as a [`SymbolicView`](@ref) object.
27+
- `p`: The current value of the selected `psym` parameters.
28+
- `t`: The current simulation time.
29+
- `sym`: A vector or tuple of symbols, which represent **states** (including
30+
inputs, outputs, observed) of the component model. Determines, which states will
31+
be available thorugh parameter `u` in the callback condition function `f`.
32+
- `psym`: A vector or tuple of symbols, which represetn **parameters** of the component mode.
33+
Determines, which parameters will be available in the condition function `f`
34+
35+
# Example
36+
Consider a component model with states `[:u1, :u2]`, inputs `[:i]`, outputs
37+
`[:o]` and parameters `[:p1, :p2]`.
38+
39+
ComponentCondition([:u1, :o], [:p1]) do (u, p, t)
40+
# access states symbolicially or via int index
41+
u[:u1] == u[1]
42+
u[:o] == u[2]
43+
p[:p1] == p[1]
44+
# the states/prameters `:u2`, `:i` and `:p2` are not available as
45+
# they are not listed in the `sym` and `psym` arguments.
46+
end
47+
"""
348
struct ComponentCondition{C,DIM,PDIM}
449
f::C
550
sym::NTuple{DIM,Symbol}
@@ -9,6 +54,38 @@ struct ComponentCondition{C,DIM,PDIM}
954
end
1055
end
1156

57+
"""
58+
ComponentCondition(f::Function, sym, psym)
59+
60+
Creates a callback condition for a [`ComponentCallback`].
61+
- `f`: The affect function. Must be a function of the form `f(u, p, [event_idx], ctx)` where `event_idx`
62+
is only available in [`VectorContinouseComponentcallback`](@ref).
63+
- Arguments of `f`
64+
- `u`: The current (mutable) value of the selected `sym` states, provided as a [`SymbolicView`](@ref) object.
65+
- `p`: The current (mutalbe) value of the selected `psym` parameters.
66+
- `event_idx`: The current event index, i.e. which `out` element triggerd in case of [`VectorContinousComponentCallback`](@ref).
67+
- `ctx::NamedTuple` a named tuple with context variables.
68+
- `ctx.model`: a referenc to the ocmponent model
69+
- `ctx.vidx`/ctx.eidx: The index of the vertex/edge model.
70+
- `ctx.src`/`ctx.dst`: src and dst indices (only for edge models).
71+
- `ctx.integrator`: The integrator object.
72+
- `ctx.t=ctx.integrator.t`: The current simulation time.
73+
- `sym`: A vector or tuple of symbols, which represent **states** (**excluding**
74+
inputs, outputs, observed) of the component model. Determines, which states will
75+
be available thorugh parameter `u` in the callback condition function `f`.
76+
- `psym`: A vector or tuple of symbols, which represetn **parameters** of the component mode.
77+
Determines, which parameters will be available in the condition function `f`
78+
79+
# Example
80+
Consider a component model with states `[:u1, :u2]`, inputs `[:i]`, outputs
81+
`[:o]` and parameters `[:p1, :p2]`.
82+
83+
ComponentAffect([:u1, :o], [:p1]) do (u, p, ctx)
84+
u[:u1] = 0 # change the state
85+
p[:p1] = 1 # change the parameter
86+
@info "Changed :u1 and :p1 on vertex \$(ctx.vidx)" # access context
87+
end
88+
"""
1289
struct ComponentAffect{A,DIM,PDIM}
1390
f::A
1491
sym::NTuple{DIM,Symbol}
@@ -18,6 +95,18 @@ struct ComponentAffect{A,DIM,PDIM}
1895
end
1996
end
2097

98+
"""
99+
ContinousComponentCallback(condition, affect; kwargs...)
100+
101+
Connect a [`ComponentCondition`](@ref) and a [`ComponentAffect`)[@ref] to a
102+
continous callback which can be attached to a component model using
103+
[`add_callback!`](@ref) or [`set_callback!`](@ref).
104+
105+
The `kwargs` will be forwarded to the `VectorContinuousCallback` when the component based
106+
callbacks are collected for the whole network using `get_callbacks`.
107+
[`DiffEq.jl docs`](https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/)
108+
for available options.
109+
"""
21110
struct ContinousComponentCallback{C,A,CDIM,CPDIM,ADIM,APDIM} <: ComponentCallback{C,A}
22111
condition::ComponentCondition{C,CDIM,CPDIM}
23112
affect::ComponentAffect{A,ADIM,APDIM}
@@ -30,6 +119,21 @@ function ContinousComponentCallback(condition, affect; kwargs...)
30119
ContinousComponentCallback(condition, affect, NamedTuple(kwargs))
31120
end
32121

122+
"""
123+
VectorContinousComponentCallback(condition, affect, len; kwargs...)
124+
125+
Connect a [`ComponentCondition`](@ref) and a [`ComponentAffect`)[@ref] to a
126+
continous callback which can be attached to a component model using
127+
[`add_callback!`](@ref) or [`set_callback!`](@ref). This vector version allows
128+
for `condions` which have `len` output dimensions.
129+
The `affect` will be triggered with the additional `event_idx` argument to know in which
130+
dimension the zerocrossing was detected.
131+
132+
The `kwargs` will be forwarded to the `VectorContinuousCallback` when the component based
133+
callbacks are collected for the whole network using `get_callbacks`.
134+
[`DiffEq.jl docs`](https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/)
135+
for available options.
136+
"""
33137
struct VectorContinousComponentCallback{C,A,CDIM,CPDIM,ADIM,APDIM} <: ComponentCallback{C,A}
34138
condition::ComponentCondition{C,CDIM,CPDIM}
35139
affect::ComponentAffect{A,ADIM,APDIM}
@@ -43,25 +147,43 @@ function VectorContinousComponentCallback(condition, affect, len; kwargs...)
43147
VectorContinousComponentCallback(condition, affect, len, NamedTuple(kwargs))
44148
end
45149

46-
struct CallbackBatch{T<:ComponentCallback,C,A,ST<:SymbolicIndex}
150+
151+
"""
152+
get_callbacks(nw::Network)::CallbackSet
153+
154+
Returns a `CallbackSet` composed of all the "component-based" callbacks in the metadata of the
155+
Network components.
156+
"""
157+
function get_callbacks(nw::Network)
158+
cbbs = collect_callbackbatches(nw)
159+
if isempty(cbbs)
160+
return nothing
161+
elseif length(cbbs) == 1
162+
return to_callback(only(cbbs))
163+
else
164+
CallbackSet(to_callback.(cbbs)...)
165+
end
166+
end
167+
####
168+
#### batching of callbacks
169+
####
170+
struct CallbackBatch{T<:ComponentCallback,C,ST<:SymbolicIndex}
47171
nw::Network
48172
components::Vector{ST}
49173
callbacks::Vector{T}
50174
sublen::Int # length of each callback
51175
condition::C
52-
affect::A
53176
end
54177
function CallbackBatch(nw, components, callbacks)
55178
if !isconcretetype(eltype(components))
56-
components = Vector{typeof(first(components))}(components)
179+
components = [c for c in components]
57180
end
58181
if !isconcretetype(eltype(callbacks))
59-
callbacks = Vector{typeof(first(callbacks))}(callbacks)
182+
callbacks = [cb for cb in callbacks]
60183
end
61184
sublen = eltype(callbacks) <: ContinousComponentCallback ? 1 : first(callbacks).len
62185
condition = first(callbacks).condition.f
63-
affect = first(callbacks).affect.f
64-
CallbackBatch(nw, components, callbacks, sublen, condition, affect)
186+
CallbackBatch(nw, components, callbacks, sublen, condition)
65187
end
66188

67189
Base.length(cbb::CallbackBatch) = length(cbb.callbacks)
@@ -70,17 +192,18 @@ cbtype(cbb::CallbackBatch{T}) where {T} = T
70192

71193
condition_dim(cbb) = first(cbb.callbacks).condition.sym |> length
72194
condition_pdim(cbb) = first(cbb.callbacks).condition.psym |> length
73-
affect_dim(cbb) = first(cbb.callbacks).affect.sym |> length
74-
affect_pdim(cbb) = first(cbb.callbacks).affect.psym |> length
195+
affect_dim(cbb,i) = cbb.callbacks[i].affect.sym |> length
196+
affect_pdim(cbb,i) = cbb.callbacks[i].affect.psym |> length
75197

76198
condition_urange(cbb, i) = (1 + (i-1)*condition_dim(cbb)) : i*condition_dim(cbb)
77199
condition_prange(cbb, i) = (1 + (i-1)*condition_pdim(cbb)) : i*condition_pdim(cbb)
78-
affect_urange(cbb, i) = (1 + (i-1)*affect_dim(cbb) ) : i*affect_dim(cbb)
79-
affect_prange(cbb, i) = (1 + (i-1)*affect_pdim(cbb)) : i*affect_pdim(cbb)
200+
affect_urange(cbb, i) = (1 + (i-1)*affect_dim(cbb,i) ) : i*affect_dim(cbb,i)
201+
affect_prange(cbb, i) = (1 + (i-1)*affect_pdim(cbb,i)) : i*affect_pdim(cbb,i)
80202

81203
condition_outrange(cbb, i) = (1 + (i-1)*cbb.sublen) : i*cbb.sublen
82204

83205
cbidx_from_outidx(cbb, outidx) = div(outidx-1, cbb.sublen) + 1
206+
subidx_from_outidx(cbb, outidx) = mod(outidx, 1:cbb.sublen)
84207

85208
function collect_c_or_a_indices(cbb::CallbackBatch, c_or_a, u_or_p)
86209
sidxs = SymbolicIndex[]
@@ -114,7 +237,6 @@ function collect_callbackbatches(nw)
114237
push!(callbacks, cb)
115238
end
116239
end
117-
118240
idx_per_type = _find_identical(callbacks, 1:length(components))
119241
batches = CallbackBatch[]
120242
for typeidx in idx_per_type
@@ -128,15 +250,15 @@ end
128250

129251
function batchequal(a::ContinousComponentCallback, b::ContinousComponentCallback)
130252
batchequal(a.condition, b.condition) || return false
131-
batchequal(a.affect, b.affect) || return false
253+
# batchequal(a.affect, b.affect) || return false
132254
batchequal(a.kwargs, b.kwargs) || return false
133255
return true
134256
end
135257
function batchequal(a::VectorContinousComponentCallback, b::VectorContinousComponentCallback)
136258
batchequal(a.condition, b.condition) || return false
137-
batchequal(a.affect, b.affect) || return false
259+
# batchequal(a.affect, b.affect) || return false
138260
batchequal(a.kwargs, b.kwargs) || return false
139-
batchequal(a.len, b.len) || return false
261+
a.len == b.len || return false
140262
return true
141263
end
142264
function batchequal(a::T, b::T) where {T <: Union{ComponentCondition, ComponentAffect}}
@@ -151,7 +273,19 @@ function batchequal(a::NamedTuple, b::NamedTuple)
151273
return true
152274
end
153275

154-
function batch_condition(cbb)
276+
"""
277+
to_callback(cbb:CallbackBatch)
278+
279+
Generate a `VectorContinuousCallback` from a callback batch.
280+
"""
281+
function to_callback(cbb::CallbackBatch)
282+
kwargs = first(cbb.callbacks).kwargs
283+
cond = _batch_condition(cbb)
284+
affect = _batch_affect(cbb)
285+
len = cbb.sublen * length(cbb.callbacks)
286+
VectorContinuousCallback(cond, affect, len; kwargs...)
287+
end
288+
function _batch_condition(cbb::CallbackBatch)
155289
usymidxs = collect_c_or_a_indices(cbb, :condition, :sym)
156290
psymidxs = collect_c_or_a_indices(cbb, :condition, :psym)
157291
ucache = DiffCache(zeros(length(usymidxs)), 12)
@@ -190,12 +324,14 @@ function batch_condition(cbb)
190324
elseif cbtype(cbb) <: VectorContinousComponentCallback
191325
@views _out = out[condition_outrange(cbb, i)]
192326
cbb.condition(_out, _u, _p, t)
327+
else
328+
error()
193329
end
194330
end
195331
nothing
196332
end
197333
end
198-
function batch_affect(cbb)
334+
function _batch_affect(cbb::CallbackBatch)
199335
usymidxs = collect_c_or_a_indices(cbb, :affect, :sym)
200336
psymidxs = collect_c_or_a_indices(cbb, :affect, :psym)
201337

@@ -226,7 +362,15 @@ function batch_affect(cbb)
226362

227363
uhash = hash(uv)
228364
phash = hash(pv)
229-
cbb.affect(_u, _p, get_ctx(integrator, cbb, cbb.components[i]))
365+
ctx = _get_ctx(integrator, cbb, cbb.components[i])
366+
if cbtype(cbb) <: ContinousComponentCallback
367+
cbb.callbacks[i].affect.f(_u, _p, ctx)
368+
elseif cbtype(cbb) <: VectorContinousComponentCallback
369+
num = subidx_from_outidx(cbb, outidx)
370+
cbb.callbacks[i].affect.f(_u, _p, num, ctx)
371+
else
372+
error()
373+
end
230374
pchanged = hash(pv) != phash
231375
uchanged = hash(uv) != uhash
232376

@@ -235,18 +379,36 @@ function batch_affect(cbb)
235379
end
236380
end
237381

238-
get_ctx(cbb, i::Int) = get_ctx(cbb, cbb.components[i])
239-
function get_ctx(integrator, cbb, sym::VIndex)
382+
_get_ctx(cbb, i::Int) = _get_ctx(cbb, cbb.components[i])
383+
function _get_ctx(integrator, cbb, sym::VIndex)
240384
idx = sym.compidx
241385
(; integrator, t=integrator.t, model=cbb.nw.im.vertexm[idx], vidx=idx)
242386
end
243-
function get_ctx(integrator, cbb, sym::EIndex)
387+
function _get_ctx(integrator, cbb, sym::EIndex)
244388
idx = sym.compidx
245389
edge = cbb.nw.im.edgevec[idx]
246390
(; integrator, t=integrator.t, model=cbb.nw.im.edgem[idx], eidx=idx, src=edge.src, dst=edge.dst)
247391
end
248392

249-
struct SymbolicView{N,VT}
393+
####
394+
#### SymbolicView helper type
395+
####
396+
"""
397+
SymbolicView{N,VT} <: AbstractVetor{VT}
398+
399+
Is a (smallish) fixed size vector type with named dimensions.
400+
Its main purpose is to allow named acces to variables in
401+
[`ComponentCondition`](@ref) and [`ComponentAffect`](@ref) functions.
402+
403+
```jldoctest
404+
julia> sv = SymbolicView([1,2,3],(:a,:b,:c))
405+
SymbolicView{3, Vector{Int64}}([1, 2, 3], (:a, :b, :c))
406+
407+
julia> sv[1] == sv[:a]
408+
true
409+
```
410+
"""
411+
struct SymbolicView{N,VT} <: AbstractVector{VT}
250412
v::VT
251413
syms::NTuple{N,Symbol}
252414
end
@@ -270,9 +432,12 @@ end
270432
_sym_to_int(x::SymbolicView, idx::Int) = idx
271433
_sym_to_int(x::SymbolicView, idx) = _sym_to_int.(Ref(x), idx)
272434

435+
####
436+
#### Internal function to check cb compat when added as metadata
437+
####
273438
assert_cb_compat(comp::ComponentModel, t::Tuple) = assert_cb_compat.(Ref(comp), t)
274439
function assert_cb_compat(comp::ComponentModel, cb)
275-
all_obssym = Set(comp.obssym) insym_all(comp) outsym_flat(comp)
440+
all_obssym = Set(sym(comp)) Set(comp.obssym) insym_all(comp) outsym_flat(comp)
276441
pcond = s -> s in comp.psym
277442
ucond_cond = s -> s in all_obssym
278443
ucond_affect = s -> s in comp.sym
@@ -295,22 +460,3 @@ function assert_cb_compat(comp::ComponentModel, cb)
295460
end
296461
cb
297462
end
298-
299-
function to_callback(cbb::CallbackBatch)
300-
kwargs = first(cbb.callbacks).kwargs
301-
cond = batch_condition(cbb)
302-
affect = batch_affect(cbb)
303-
len = cbb.sublen * length(cbb.callbacks)
304-
VectorContinuousCallback(cond, affect, len; kwargs...)
305-
end
306-
307-
function get_callbacks(nw::Network)
308-
cbbs = collect_callbackbatches(nw)
309-
if isempty(cbbs)
310-
return nothing
311-
elseif length(cbbs) == 1
312-
return to_callback(only(cbbs))
313-
else
314-
CallbackSet(to_callback.(cbbs)...)
315-
end
316-
end

src/metadata.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ See also [`add_callback!`](@ref).
207207
"""
208208
function set_callback!(c::ComponentModel, cb; check=true)
209209
if !(cb isa ComponentCallback) && !(cb isa NTuple{N, <:ComponentCallback} where N)
210-
throw(ArgumentError("Callback must be a ComponentCallback or a tuple of ComponentCallbacks"))
210+
throw(ArgumentError("Callback must be a ComponentCallback or a tuple of ComponentCallbacks, got $(typeof(cb))."))
211211
end
212212
check && assert_cb_compat(c, cb)
213213
set_metadata!(c, :callback, cb)

0 commit comments

Comments
 (0)