@@ -3,23 +3,84 @@ get_continuous_events(sys::AbstractSystem) = Equation[]
3
3
get_continuous_events (sys:: AbstractODESystem ) = getfield (sys, :continuous_events )
4
4
has_continuous_events (sys:: AbstractSystem ) = isdefined (sys, :continuous_events )
5
5
6
- has_discrete_events (sys:: AbstractSystem ) = isdefined (sys, :discrete_events )
6
+ has_discrete_events (sys:: AbstractSystem ) = isdefined (sys, :discrete_events ) && length (sys . discrete_events) > 0
7
7
function get_discrete_events (sys:: AbstractSystem )
8
8
has_discrete_events (sys) || return SymbolicDiscreteCallback[]
9
9
getfield (sys, :discrete_events )
10
10
end
11
11
12
+ struct FunctionalAffect
13
+ f:: Function
14
+ sts:: Vector
15
+ sts_syms:: Vector{Symbol}
16
+ pars:: Vector
17
+ pars_syms:: Vector{Symbol}
18
+ ctx
19
+ FunctionalAffect (f, sts, sts_syms, pars, pars_syms, ctx = nothing ) = new (f,sts, sts_syms, pars, pars_syms, ctx)
20
+ end
21
+
22
+ function FunctionalAffect (f, sts, pars, ctx = nothing )
23
+ # sts & pars contain either pairs: resistor.R => R, or Syms: R
24
+ vs = [x isa Pair ? x. first : x for x in sts]
25
+ vs_syms = [x isa Pair ? Symbol (x. second) : getname (x) for x in sts]
26
+ length (vs_syms) == length (unique (vs_syms)) || error (" Variables are not unique." )
27
+
28
+ ps = [x isa Pair ? x. first : x for x in pars]
29
+ ps_syms = [x isa Pair ? Symbol (x. second) : getname (x) for x in pars]
30
+ length (ps_syms) == length (unique (ps_syms)) || error (" Parameters are not unique." )
31
+
32
+ FunctionalAffect (f, vs, vs_syms, ps, ps_syms, ctx)
33
+ end
34
+
35
+ FunctionalAffect (;f, sts, pars, ctx = nothing ) = FunctionalAffect (f, sts, pars, ctx)
36
+
37
+ func (f:: FunctionalAffect ) = f. f
38
+ context (a:: FunctionalAffect ) = a. ctx
39
+ parameters (a:: FunctionalAffect ) = a. pars
40
+ parameters_syms (a:: FunctionalAffect ) = a. pars_syms
41
+ states (a:: FunctionalAffect ) = a. sts
42
+ states_syms (a:: FunctionalAffect ) = a. sts_syms
43
+
44
+ function Base.:(== )(a1:: FunctionalAffect , a2:: FunctionalAffect )
45
+ isequal (a1. f, a2. f) && isequal (a1. sts, a2. sts) && isequal (a1. pars, a2. pars) &&
46
+ isequal (a1. sts_syms, a2. sts_syms) && isequal (a1. pars_syms, a2. pars_syms) &&
47
+ isequal (a1. ctx, a2. ctx)
48
+ end
49
+
50
+ function Base. hash (a:: FunctionalAffect , s:: UInt )
51
+ s = hash (a. f, s)
52
+ s = hash (a. sts, s)
53
+ s = hash (a. sts_syms, s)
54
+ s = hash (a. pars, s)
55
+ s = hash (a. pars_syms, s)
56
+ hash (a. ctx, s)
57
+ end
58
+
59
+ has_functional_affect (cb) = affects (cb) isa FunctionalAffect
60
+
61
+ namespace_affect (affect, s) = namespace_equation (affect, s)
62
+ function namespace_affect (affect:: FunctionalAffect , s)
63
+ FunctionalAffect (func (affect),
64
+ renamespace .((s,), states (affect)),
65
+ states_syms (affect),
66
+ renamespace .((s,), parameters (affect)),
67
+ parameters_syms (affect),
68
+ context (affect))
69
+ end
70
+
12
71
# ################################### continuous events #####################################
13
72
14
73
const NULL_AFFECT = Equation[]
15
74
struct SymbolicContinuousCallback
16
75
eqs:: Vector{Equation}
17
- affect:: Vector{Equation}
76
+ affect
18
77
function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT)
19
78
new (eqs, affect)
20
79
end # Default affect to nothing
21
80
end
22
81
82
+ SymbolicContinuousCallback (eqs:: Vector{Equation} , affect:: Function ) = SymbolicContinuousCallback (eqs, SymbolicContinuousCallback (affect))
83
+
23
84
function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
24
85
isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect)
25
86
end
@@ -57,22 +118,26 @@ equations(cb::SymbolicContinuousCallback) = cb.eqs
57
118
function equations (cbs:: Vector{<:SymbolicContinuousCallback} )
58
119
reduce (vcat, [equations (cb) for cb in cbs])
59
120
end
60
- affect_equations (cb:: SymbolicContinuousCallback ) = cb. affect
61
- function affect_equations (cbs:: Vector{SymbolicContinuousCallback} )
62
- reduce (vcat, [affect_equations (cb) for cb in cbs])
121
+ affects (cb:: SymbolicContinuousCallback ) = cb. affect
122
+ function affects (cbs:: Vector{SymbolicContinuousCallback} )
123
+ reduce (vcat, [affects (cb) for cb in cbs])
63
124
end
64
- function namespace_equation (cb:: SymbolicContinuousCallback , s):: SymbolicContinuousCallback
125
+
126
+ function namespace_callback (cb:: SymbolicContinuousCallback , s):: SymbolicContinuousCallback
65
127
SymbolicContinuousCallback (namespace_equation .(equations (cb), (s,)),
66
- namespace_equation .( affect_equations (cb), (s,)))
128
+ namespace_affect .( affects (cb), (s,)))
67
129
end
68
130
131
+ cb_add_context (cb:: SymbolicContinuousCallback , s) = SymbolicContinuousCallback (equations (cb), af_add_context (affects (cb), s))
132
+
69
133
function continuous_events (sys:: AbstractSystem )
70
134
obs = get_continuous_events (sys)
71
135
filter (! isempty, obs)
136
+
72
137
systems = get_systems (sys)
73
138
cbs = [obs;
74
139
reduce (vcat,
75
- (map (o -> namespace_equation (o, s), continuous_events (s))
140
+ (map (o -> namespace_callback (o, s), continuous_events (s))
76
141
for s in systems),
77
142
init = SymbolicContinuousCallback[])]
78
143
filter (! isempty, cbs)
81
146
# ################################### continuous events #####################################
82
147
83
148
struct SymbolicDiscreteCallback
149
+ # condition can be one of:
150
+ # TODO : Iterative
151
+ # Δt::Real - Periodic with period Δt
152
+ # Δts::Vector{Real} - events trigger in this times (Preset)
153
+ # condition::Vector{Equation} - event triggered when condition is true
84
154
condition
85
- affects:: Vector{Equation}
155
+ affects
156
+
86
157
function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
87
- c = value ( scalarize ( condition) )
88
- a = scalarize (affects)
89
- new (c, a)
158
+ c = scalarize_condition ( condition)
159
+ a = scalarize_affects (affects)
160
+ new (c,a)
90
161
end # Default affect to nothing
91
162
end
92
163
164
+ is_timed_condition (cb) = false
165
+ is_timed_condition (:: R ) where {R<: Real } = true
166
+ is_timed_condition (:: V ) where {V<: AbstractVector } = eltype (V) <: Real
167
+ is_timed_condition (cb:: SymbolicDiscreteCallback ) = is_timed_condition (condition (cb))
168
+
169
+ scalarize_condition (condition) = is_timed_condition (condition) ? condition : value (scalarize (condition))
170
+ namespace_condition (condition, s) = is_timed_condition (condition) ? condition : namespace_expr (condition, s)
171
+
172
+ scalarize_affects (affects) = scalarize (affects)
173
+ scalarize_affects (affects:: Tuple ) = FunctionalAffect (affects... )
174
+ scalarize_affects (affects:: NamedTuple ) = FunctionalAffect (;affects... )
175
+ scalarize_affects (affects:: FunctionalAffect ) = affects
176
+
93
177
SymbolicDiscreteCallback (p:: Pair ) = SymbolicDiscreteCallback (p[1 ], p[2 ])
94
178
SymbolicDiscreteCallback (cb:: SymbolicDiscreteCallback ) = cb # passthrough
95
179
96
180
function Base. show (io:: IO , db:: SymbolicDiscreteCallback )
97
181
println (io, " condition: " , db. condition)
98
182
println (io, " affects:" )
99
- for affect in db. affects
100
- println (io, " " , affect)
183
+ if db. affects isa FunctionalAffect
184
+ # TODO
185
+ println (io, " " , db. affects)
186
+ else
187
+ for affect in db. affects
188
+ println (io, " " , affect)
189
+ end
101
190
end
102
191
end
103
192
@@ -106,21 +195,24 @@ function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback)
106
195
end
107
196
function Base. hash (cb:: SymbolicDiscreteCallback , s:: UInt )
108
197
s = foldr (hash, cb. condition, init = s)
109
- foldr (hash, cb. affects, init = s)
198
+ cb . affects isa AbstractVector ? foldr (hash, cb. affects, init = s) : hash (cb . affects, s)
110
199
end
111
200
112
201
condition (cb:: SymbolicDiscreteCallback ) = cb. condition
113
202
function conditions (cbs:: Vector{<:SymbolicDiscreteCallback} )
114
203
reduce (vcat, condition (cb) for cb in cbs)
115
204
end
116
205
117
- affect_equations (cb:: SymbolicDiscreteCallback ) = cb. affects
118
- function affect_equations (cbs:: Vector{SymbolicDiscreteCallback} )
119
- reduce (vcat, affect_equations (cb) for cb in cbs)
206
+ affects (cb:: SymbolicDiscreteCallback ) = cb. affects
207
+
208
+ function affects (cbs:: Vector{SymbolicDiscreteCallback} )
209
+ reduce (vcat, affects (cb) for cb in cbs)
120
210
end
121
- function namespace_equation (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
122
- SymbolicDiscreteCallback (namespace_expr (condition (cb), s),
123
- namespace_equation .(affect_equations (cb), Ref (s)))
211
+
212
+ function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
213
+ af = affects (cb)
214
+ af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
215
+ SymbolicDiscreteCallback (namespace_condition (condition (cb), s), af)
124
216
end
125
217
126
218
SymbolicDiscreteCallbacks (cb:: Pair ) = SymbolicDiscreteCallback[SymbolicDiscreteCallback (cb)]
@@ -134,7 +226,7 @@ function discrete_events(sys::AbstractSystem)
134
226
systems = get_systems (sys)
135
227
cbs = [obs;
136
228
reduce (vcat,
137
- (map (o -> namespace_equation (o, s), discrete_events (s)) for s in systems),
229
+ (map (o -> namespace_callback (o, s), discrete_events (s)) for s in systems),
138
230
init = SymbolicDiscreteCallback[])]
139
231
cbs
140
232
end
@@ -178,7 +270,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
178
270
end
179
271
180
272
function compile_affect (cb:: SymbolicContinuousCallback , args... ; kwargs... )
181
- compile_affect (affect_equations (cb), args... ; kwargs... )
273
+ compile_affect (affects (cb), args... ; kwargs... )
182
274
end
183
275
184
276
"""
@@ -208,7 +300,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
208
300
outvar = :u
209
301
if outputidxs === nothing
210
302
lhss = map (x -> x. lhs, eqs)
211
- update_vars = collect (Iterators. flatten (map (ModelingToolkit. vars, lhss))) # these are the ones we're chaning
303
+ update_vars = collect (Iterators. flatten (map (ModelingToolkit. vars, lhss))) # these are the ones we're changing
212
304
length (update_vars) == length (unique (update_vars)) == length (eqs) ||
213
305
error (" affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair." )
214
306
alleq = all (isequal (isparameter (first (update_vars))),
@@ -303,17 +395,70 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states
303
395
end
304
396
end
305
397
398
+ function compile_user_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
399
+ ind (sym, v) = findfirst (isequal (sym), v)
400
+ inds (syms, v) = map (sym -> ind (sym, v), syms)
401
+ v_inds = inds (states (affect), dvs)
402
+ p_inds = inds (parameters (affect), ps)
403
+
404
+ # HACK: filter out eliminated symbols. Not clear this is the right thing to do
405
+ # (MTK should keep these symbols)
406
+ v = filter (x -> ! isnothing (x[1 ]), collect (zip (v_inds, states_syms (affect))))
407
+ v_inds = [x[1 ] for x in v]
408
+ v_syms = Tuple ([x[2 ] for x in v])
409
+ p = filter (x -> ! isnothing (x[1 ]), collect (zip (p_inds, parameters_syms (affect))))
410
+ p_inds = [x[1 ] for x in p]
411
+ p_syms = Tuple ([x[2 ] for x in p])
412
+
413
+ let v_inds= v_inds, p_inds= p_inds, v_syms= v_syms, p_syms= p_syms, user_affect= func (affect), ctx = context (affect)
414
+ function (integ)
415
+ uv = @views integ. u[v_inds]
416
+ pv = @views integ. p[p_inds]
417
+
418
+ u = LArray {v_syms} (uv)
419
+ p = LArray {p_syms} (pv)
420
+
421
+ user_affect (integ. t, u, p, ctx)
422
+ end
423
+ end
424
+ end
425
+
426
+ function compile_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
427
+ compile_user_affect (affect, sys, dvs, ps; kwargs... )
428
+ end
429
+
430
+ function generate_timed_callback (cb, sys, dvs, ps; kwargs... )
431
+ cond = condition (cb)
432
+ as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
433
+ kwargs... )
434
+ if cond isa AbstractVector
435
+ # Preset Time
436
+ return PresetTimeCallback (cond, as)
437
+ else
438
+ # Periodic
439
+ return PeriodicCallback (as, cond)
440
+ end
441
+ end
442
+
443
+ function generate_discrete_callback (cb, sys, dvs, ps; kwargs... )
444
+ if is_timed_condition (cb)
445
+ return generate_timed_callback (cb, sys, dvs, ps, kwargs... )
446
+ else
447
+ c = compile_condition (cb, sys, dvs, ps; expression= Val{false }, kwargs... )
448
+ as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
449
+ kwargs... )
450
+ return DiscreteCallback (c, as)
451
+ end
452
+ end
453
+
306
454
function generate_discrete_callbacks (sys:: AbstractSystem , dvs = states (sys),
307
455
ps = parameters (sys); kwargs... )
308
456
has_discrete_events (sys) || return nothing
309
457
symcbs = discrete_events (sys)
310
458
isempty (symcbs) && return nothing
311
459
312
460
dbs = map (symcbs) do cb
313
- c = compile_condition (cb, sys, dvs, ps; expression= Val{false }, kwargs... )
314
- as = compile_affect (affect_equations (cb), sys, dvs, ps; expression = Val{false },
315
- kwargs... )
316
- DiscreteCallback (c, as)
461
+ generate_discrete_callback (cb, sys, dvs, ps; kwargs... )
317
462
end
318
463
319
464
dbs
@@ -333,7 +478,7 @@ function process_events(sys; callback = nothing, has_difference = false, kwargs.
333
478
if has_discrete_events (sys)
334
479
discrete_cb = generate_discrete_callbacks (sys; kwargs... )
335
480
else
336
- discrete_cb = nothing
481
+ discrete_cb = []
337
482
end
338
483
difference_cb = has_difference ? generate_difference_cb (sys; kwargs... ) : nothing
339
484
0 commit comments