1
+ """
2
+ abstract type ComponentCallback{C,A} end
3
+
4
+ Abstract type for a component based callback. A component callback
5
+ bundles a [`ComponentCondition`](@ref) as well as a [`ComponentAffect`](@ref)
6
+ which can be then tied to a component model using [`add_callback!`](@ref) or
7
+ [`set_callback!`](@ref).
8
+
9
+ On a Network level, you can automaticially create network wide `CallbackSet`s using
10
+ [`get_callbacks`](@ref).
11
+
12
+ See
13
+ [`ContinousComponentCallback`](@ref) and [`VectorContinousComponentCallback`](@ref) for concrete
14
+ implemenations of this abstract type.
15
+ """
1
16
abstract type ComponentCallback{C,A} end
2
17
18
+ """
19
+ ComponentCondition(f::Function, sym, psym)
20
+
21
+ Creates a callback condition for a [`ComponentCallback`].
22
+ - `f`: The condition function. Must be a function of the form `out=f(u, p, t)` when used
23
+ for [`ContinouseComponentcallback`](@ref) or `f!(out, u, p, t)` when used for
24
+ [`VectorContinousComponentCallback`](@ref).
25
+ - Arguments of `f`
26
+ - `u`: The current value of the selecte `sym` states, provided as a [`SymbolicView`](@ref) object.
27
+ - `p`: The current value of the selected `psym` parameters.
28
+ - `t`: The current simulation time.
29
+ - `sym`: A vector or tuple of symbols, which represent **states** (including
30
+ inputs, outputs, observed) of the component model. Determines, which states will
31
+ be available thorugh parameter `u` in the callback condition function `f`.
32
+ - `psym`: A vector or tuple of symbols, which represetn **parameters** of the component mode.
33
+ Determines, which parameters will be available in the condition function `f`
34
+
35
+ # Example
36
+ Consider a component model with states `[:u1, :u2]`, inputs `[:i]`, outputs
37
+ `[:o]` and parameters `[:p1, :p2]`.
38
+
39
+ ComponentCondition([:u1, :o], [:p1]) do (u, p, t)
40
+ # access states symbolicially or via int index
41
+ u[:u1] == u[1]
42
+ u[:o] == u[2]
43
+ p[:p1] == p[1]
44
+ # the states/prameters `:u2`, `:i` and `:p2` are not available as
45
+ # they are not listed in the `sym` and `psym` arguments.
46
+ end
47
+ """
3
48
struct ComponentCondition{C,DIM,PDIM}
4
49
f:: C
5
50
sym:: NTuple{DIM,Symbol}
@@ -9,6 +54,38 @@ struct ComponentCondition{C,DIM,PDIM}
9
54
end
10
55
end
11
56
57
+ """
58
+ ComponentCondition(f::Function, sym, psym)
59
+
60
+ Creates a callback condition for a [`ComponentCallback`].
61
+ - `f`: The affect function. Must be a function of the form `f(u, p, [event_idx], ctx)` where `event_idx`
62
+ is only available in [`VectorContinouseComponentcallback`](@ref).
63
+ - Arguments of `f`
64
+ - `u`: The current (mutable) value of the selected `sym` states, provided as a [`SymbolicView`](@ref) object.
65
+ - `p`: The current (mutalbe) value of the selected `psym` parameters.
66
+ - `event_idx`: The current event index, i.e. which `out` element triggerd in case of [`VectorContinousComponentCallback`](@ref).
67
+ - `ctx::NamedTuple` a named tuple with context variables.
68
+ - `ctx.model`: a referenc to the ocmponent model
69
+ - `ctx.vidx`/ctx.eidx: The index of the vertex/edge model.
70
+ - `ctx.src`/`ctx.dst`: src and dst indices (only for edge models).
71
+ - `ctx.integrator`: The integrator object.
72
+ - `ctx.t=ctx.integrator.t`: The current simulation time.
73
+ - `sym`: A vector or tuple of symbols, which represent **states** (**excluding**
74
+ inputs, outputs, observed) of the component model. Determines, which states will
75
+ be available thorugh parameter `u` in the callback condition function `f`.
76
+ - `psym`: A vector or tuple of symbols, which represetn **parameters** of the component mode.
77
+ Determines, which parameters will be available in the condition function `f`
78
+
79
+ # Example
80
+ Consider a component model with states `[:u1, :u2]`, inputs `[:i]`, outputs
81
+ `[:o]` and parameters `[:p1, :p2]`.
82
+
83
+ ComponentAffect([:u1, :o], [:p1]) do (u, p, ctx)
84
+ u[:u1] = 0 # change the state
85
+ p[:p1] = 1 # change the parameter
86
+ @info "Changed :u1 and :p1 on vertex \$ (ctx.vidx)" # access context
87
+ end
88
+ """
12
89
struct ComponentAffect{A,DIM,PDIM}
13
90
f:: A
14
91
sym:: NTuple{DIM,Symbol}
@@ -18,6 +95,18 @@ struct ComponentAffect{A,DIM,PDIM}
18
95
end
19
96
end
20
97
98
+ """
99
+ ContinousComponentCallback(condition, affect; kwargs...)
100
+
101
+ Connect a [`ComponentCondition`](@ref) and a [`ComponentAffect`)[@ref] to a
102
+ continous callback which can be attached to a component model using
103
+ [`add_callback!`](@ref) or [`set_callback!`](@ref).
104
+
105
+ The `kwargs` will be forwarded to the `VectorContinuousCallback` when the component based
106
+ callbacks are collected for the whole network using `get_callbacks`.
107
+ [`DiffEq.jl docs`](https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/)
108
+ for available options.
109
+ """
21
110
struct ContinousComponentCallback{C,A,CDIM,CPDIM,ADIM,APDIM} <: ComponentCallback{C,A}
22
111
condition:: ComponentCondition{C,CDIM,CPDIM}
23
112
affect:: ComponentAffect{A,ADIM,APDIM}
@@ -30,6 +119,21 @@ function ContinousComponentCallback(condition, affect; kwargs...)
30
119
ContinousComponentCallback (condition, affect, NamedTuple (kwargs))
31
120
end
32
121
122
+ """
123
+ VectorContinousComponentCallback(condition, affect, len; kwargs...)
124
+
125
+ Connect a [`ComponentCondition`](@ref) and a [`ComponentAffect`)[@ref] to a
126
+ continous callback which can be attached to a component model using
127
+ [`add_callback!`](@ref) or [`set_callback!`](@ref). This vector version allows
128
+ for `condions` which have `len` output dimensions.
129
+ The `affect` will be triggered with the additional `event_idx` argument to know in which
130
+ dimension the zerocrossing was detected.
131
+
132
+ The `kwargs` will be forwarded to the `VectorContinuousCallback` when the component based
133
+ callbacks are collected for the whole network using `get_callbacks`.
134
+ [`DiffEq.jl docs`](https://docs.sciml.ai/DiffEqDocs/stable/features/callback_functions/)
135
+ for available options.
136
+ """
33
137
struct VectorContinousComponentCallback{C,A,CDIM,CPDIM,ADIM,APDIM} <: ComponentCallback{C,A}
34
138
condition:: ComponentCondition{C,CDIM,CPDIM}
35
139
affect:: ComponentAffect{A,ADIM,APDIM}
@@ -43,25 +147,43 @@ function VectorContinousComponentCallback(condition, affect, len; kwargs...)
43
147
VectorContinousComponentCallback (condition, affect, len, NamedTuple (kwargs))
44
148
end
45
149
46
- struct CallbackBatch{T<: ComponentCallback ,C,A,ST<: SymbolicIndex }
150
+
151
+ """
152
+ get_callbacks(nw::Network)::CallbackSet
153
+
154
+ Returns a `CallbackSet` composed of all the "component-based" callbacks in the metadata of the
155
+ Network components.
156
+ """
157
+ function get_callbacks (nw:: Network )
158
+ cbbs = collect_callbackbatches (nw)
159
+ if isempty (cbbs)
160
+ return nothing
161
+ elseif length (cbbs) == 1
162
+ return to_callback (only (cbbs))
163
+ else
164
+ CallbackSet (to_callback .(cbbs)... )
165
+ end
166
+ end
167
+ # ###
168
+ # ### batching of callbacks
169
+ # ###
170
+ struct CallbackBatch{T<: ComponentCallback ,C,ST<: SymbolicIndex }
47
171
nw:: Network
48
172
components:: Vector{ST}
49
173
callbacks:: Vector{T}
50
174
sublen:: Int # length of each callback
51
175
condition:: C
52
- affect:: A
53
176
end
54
177
function CallbackBatch (nw, components, callbacks)
55
178
if ! isconcretetype (eltype (components))
56
- components = Vector {typeof(first( components))} (components)
179
+ components = [c for c in components]
57
180
end
58
181
if ! isconcretetype (eltype (callbacks))
59
- callbacks = Vector {typeof(first( callbacks))} (callbacks)
182
+ callbacks = [cb for cb in callbacks]
60
183
end
61
184
sublen = eltype (callbacks) <: ContinousComponentCallback ? 1 : first (callbacks). len
62
185
condition = first (callbacks). condition. f
63
- affect = first (callbacks). affect. f
64
- CallbackBatch (nw, components, callbacks, sublen, condition, affect)
186
+ CallbackBatch (nw, components, callbacks, sublen, condition)
65
187
end
66
188
67
189
Base. length (cbb:: CallbackBatch ) = length (cbb. callbacks)
@@ -70,17 +192,18 @@ cbtype(cbb::CallbackBatch{T}) where {T} = T
70
192
71
193
condition_dim (cbb) = first (cbb. callbacks). condition. sym |> length
72
194
condition_pdim (cbb) = first (cbb. callbacks). condition. psym |> length
73
- affect_dim (cbb) = first ( cbb. callbacks) . affect. sym |> length
74
- affect_pdim (cbb) = first ( cbb. callbacks) . affect. psym |> length
195
+ affect_dim (cbb,i ) = cbb. callbacks[i] . affect. sym |> length
196
+ affect_pdim (cbb,i ) = cbb. callbacks[i] . affect. psym |> length
75
197
76
198
condition_urange (cbb, i) = (1 + (i- 1 )* condition_dim (cbb)) : i* condition_dim (cbb)
77
199
condition_prange (cbb, i) = (1 + (i- 1 )* condition_pdim (cbb)) : i* condition_pdim (cbb)
78
- affect_urange (cbb, i) = (1 + (i- 1 )* affect_dim (cbb) ) : i* affect_dim (cbb)
79
- affect_prange (cbb, i) = (1 + (i- 1 )* affect_pdim (cbb)) : i* affect_pdim (cbb)
200
+ affect_urange (cbb, i) = (1 + (i- 1 )* affect_dim (cbb,i ) ) : i* affect_dim (cbb,i )
201
+ affect_prange (cbb, i) = (1 + (i- 1 )* affect_pdim (cbb,i )) : i* affect_pdim (cbb,i )
80
202
81
203
condition_outrange (cbb, i) = (1 + (i- 1 )* cbb. sublen) : i* cbb. sublen
82
204
83
205
cbidx_from_outidx (cbb, outidx) = div (outidx- 1 , cbb. sublen) + 1
206
+ subidx_from_outidx (cbb, outidx) = mod (outidx, 1 : cbb. sublen)
84
207
85
208
function collect_c_or_a_indices (cbb:: CallbackBatch , c_or_a, u_or_p)
86
209
sidxs = SymbolicIndex[]
@@ -114,7 +237,6 @@ function collect_callbackbatches(nw)
114
237
push! (callbacks, cb)
115
238
end
116
239
end
117
-
118
240
idx_per_type = _find_identical (callbacks, 1 : length (components))
119
241
batches = CallbackBatch[]
120
242
for typeidx in idx_per_type
@@ -128,15 +250,15 @@ end
128
250
129
251
function batchequal (a:: ContinousComponentCallback , b:: ContinousComponentCallback )
130
252
batchequal (a. condition, b. condition) || return false
131
- batchequal (a. affect, b. affect) || return false
253
+ # batchequal(a.affect, b.affect) || return false
132
254
batchequal (a. kwargs, b. kwargs) || return false
133
255
return true
134
256
end
135
257
function batchequal (a:: VectorContinousComponentCallback , b:: VectorContinousComponentCallback )
136
258
batchequal (a. condition, b. condition) || return false
137
- batchequal (a. affect, b. affect) || return false
259
+ # batchequal(a.affect, b.affect) || return false
138
260
batchequal (a. kwargs, b. kwargs) || return false
139
- batchequal ( a. len, b. len) || return false
261
+ a. len == b. len || return false
140
262
return true
141
263
end
142
264
function batchequal (a:: T , b:: T ) where {T <: Union{ComponentCondition, ComponentAffect} }
@@ -151,7 +273,19 @@ function batchequal(a::NamedTuple, b::NamedTuple)
151
273
return true
152
274
end
153
275
154
- function batch_condition (cbb)
276
+ """
277
+ to_callback(cbb:CallbackBatch)
278
+
279
+ Generate a `VectorContinuousCallback` from a callback batch.
280
+ """
281
+ function to_callback (cbb:: CallbackBatch )
282
+ kwargs = first (cbb. callbacks). kwargs
283
+ cond = _batch_condition (cbb)
284
+ affect = _batch_affect (cbb)
285
+ len = cbb. sublen * length (cbb. callbacks)
286
+ VectorContinuousCallback (cond, affect, len; kwargs... )
287
+ end
288
+ function _batch_condition (cbb:: CallbackBatch )
155
289
usymidxs = collect_c_or_a_indices (cbb, :condition , :sym )
156
290
psymidxs = collect_c_or_a_indices (cbb, :condition , :psym )
157
291
ucache = DiffCache (zeros (length (usymidxs)), 12 )
@@ -190,12 +324,14 @@ function batch_condition(cbb)
190
324
elseif cbtype (cbb) <: VectorContinousComponentCallback
191
325
@views _out = out[condition_outrange (cbb, i)]
192
326
cbb. condition (_out, _u, _p, t)
327
+ else
328
+ error ()
193
329
end
194
330
end
195
331
nothing
196
332
end
197
333
end
198
- function batch_affect (cbb)
334
+ function _batch_affect (cbb:: CallbackBatch )
199
335
usymidxs = collect_c_or_a_indices (cbb, :affect , :sym )
200
336
psymidxs = collect_c_or_a_indices (cbb, :affect , :psym )
201
337
@@ -226,7 +362,15 @@ function batch_affect(cbb)
226
362
227
363
uhash = hash (uv)
228
364
phash = hash (pv)
229
- cbb. affect (_u, _p, get_ctx (integrator, cbb, cbb. components[i]))
365
+ ctx = _get_ctx (integrator, cbb, cbb. components[i])
366
+ if cbtype (cbb) <: ContinousComponentCallback
367
+ cbb. callbacks[i]. affect. f (_u, _p, ctx)
368
+ elseif cbtype (cbb) <: VectorContinousComponentCallback
369
+ num = subidx_from_outidx (cbb, outidx)
370
+ cbb. callbacks[i]. affect. f (_u, _p, num, ctx)
371
+ else
372
+ error ()
373
+ end
230
374
pchanged = hash (pv) != phash
231
375
uchanged = hash (uv) != uhash
232
376
@@ -235,18 +379,36 @@ function batch_affect(cbb)
235
379
end
236
380
end
237
381
238
- get_ctx (cbb, i:: Int ) = get_ctx (cbb, cbb. components[i])
239
- function get_ctx (integrator, cbb, sym:: VIndex )
382
+ _get_ctx (cbb, i:: Int ) = _get_ctx (cbb, cbb. components[i])
383
+ function _get_ctx (integrator, cbb, sym:: VIndex )
240
384
idx = sym. compidx
241
385
(; integrator, t= integrator. t, model= cbb. nw. im. vertexm[idx], vidx= idx)
242
386
end
243
- function get_ctx (integrator, cbb, sym:: EIndex )
387
+ function _get_ctx (integrator, cbb, sym:: EIndex )
244
388
idx = sym. compidx
245
389
edge = cbb. nw. im. edgevec[idx]
246
390
(; integrator, t= integrator. t, model= cbb. nw. im. edgem[idx], eidx= idx, src= edge. src, dst= edge. dst)
247
391
end
248
392
249
- struct SymbolicView{N,VT}
393
+ # ###
394
+ # ### SymbolicView helper type
395
+ # ###
396
+ """
397
+ SymbolicView{N,VT} <: AbstractVetor{VT}
398
+
399
+ Is a (smallish) fixed size vector type with named dimensions.
400
+ Its main purpose is to allow named acces to variables in
401
+ [`ComponentCondition`](@ref) and [`ComponentAffect`](@ref) functions.
402
+
403
+ ```jldoctest
404
+ julia> sv = SymbolicView([1,2,3],(:a,:b,:c))
405
+ SymbolicView{3, Vector{Int64}}([1, 2, 3], (:a, :b, :c))
406
+
407
+ julia> sv[1] == sv[:a]
408
+ true
409
+ ```
410
+ """
411
+ struct SymbolicView{N,VT} <: AbstractVector{VT}
250
412
v:: VT
251
413
syms:: NTuple{N,Symbol}
252
414
end
270
432
_sym_to_int (x:: SymbolicView , idx:: Int ) = idx
271
433
_sym_to_int (x:: SymbolicView , idx) = _sym_to_int .(Ref (x), idx)
272
434
435
+ # ###
436
+ # ### Internal function to check cb compat when added as metadata
437
+ # ###
273
438
assert_cb_compat (comp:: ComponentModel , t:: Tuple ) = assert_cb_compat .(Ref (comp), t)
274
439
function assert_cb_compat (comp:: ComponentModel , cb)
275
- all_obssym = Set (comp. obssym) ∪ insym_all (comp) ∪ outsym_flat (comp)
440
+ all_obssym = Set (sym (comp)) ∪ Set ( comp. obssym) ∪ insym_all (comp) ∪ outsym_flat (comp)
276
441
pcond = s -> s in comp. psym
277
442
ucond_cond = s -> s in all_obssym
278
443
ucond_affect = s -> s in comp. sym
@@ -295,22 +460,3 @@ function assert_cb_compat(comp::ComponentModel, cb)
295
460
end
296
461
cb
297
462
end
298
-
299
- function to_callback (cbb:: CallbackBatch )
300
- kwargs = first (cbb. callbacks). kwargs
301
- cond = batch_condition (cbb)
302
- affect = batch_affect (cbb)
303
- len = cbb. sublen * length (cbb. callbacks)
304
- VectorContinuousCallback (cond, affect, len; kwargs... )
305
- end
306
-
307
- function get_callbacks (nw:: Network )
308
- cbbs = collect_callbackbatches (nw)
309
- if isempty (cbbs)
310
- return nothing
311
- elseif length (cbbs) == 1
312
- return to_callback (only (cbbs))
313
- else
314
- CallbackSet (to_callback .(cbbs)... )
315
- end
316
- end
0 commit comments