Skip to content

Commit 8e7ef41

Browse files
author
dd
committed
added support for periodic & preset-time callbacks, generalized affect functions
1 parent c534a06 commit 8e7ef41

File tree

1 file changed

+174
-29
lines changed

1 file changed

+174
-29
lines changed

src/systems/callbacks.jl

Lines changed: 174 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,84 @@ get_continuous_events(sys::AbstractSystem) = Equation[]
33
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
44
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
55

6-
has_discrete_events(sys::AbstractSystem) = isdefined(sys, :discrete_events)
6+
has_discrete_events(sys::AbstractSystem) = isdefined(sys, :discrete_events) && length(sys.discrete_events) > 0
77
function get_discrete_events(sys::AbstractSystem)
88
has_discrete_events(sys) || return SymbolicDiscreteCallback[]
99
getfield(sys, :discrete_events)
1010
end
1111

12+
struct FunctionalAffect
13+
f::Function
14+
sts::Vector
15+
sts_syms::Vector{Symbol}
16+
pars::Vector
17+
pars_syms::Vector{Symbol}
18+
ctx
19+
FunctionalAffect(f, sts, sts_syms, pars, pars_syms, ctx = nothing) = new(f,sts, sts_syms, pars, pars_syms, ctx)
20+
end
21+
22+
function FunctionalAffect(f, sts, pars, ctx = nothing)
23+
# sts & pars contain either pairs: resistor.R => R, or Syms: R
24+
vs = [x isa Pair ? x.first : x for x in sts]
25+
vs_syms = [x isa Pair ? Symbol(x.second) : getname(x) for x in sts]
26+
length(vs_syms) == length(unique(vs_syms)) || error("Variables are not unique.")
27+
28+
ps = [x isa Pair ? x.first : x for x in pars]
29+
ps_syms = [x isa Pair ? Symbol(x.second) : getname(x) for x in pars]
30+
length(ps_syms) == length(unique(ps_syms)) || error("Parameters are not unique.")
31+
32+
FunctionalAffect(f, vs, vs_syms, ps, ps_syms, ctx)
33+
end
34+
35+
FunctionalAffect(;f, sts, pars, ctx = nothing) = FunctionalAffect(f, sts, pars, ctx)
36+
37+
func(f::FunctionalAffect) = f.f
38+
context(a::FunctionalAffect) = a.ctx
39+
parameters(a::FunctionalAffect) = a.pars
40+
parameters_syms(a::FunctionalAffect) = a.pars_syms
41+
states(a::FunctionalAffect) = a.sts
42+
states_syms(a::FunctionalAffect) = a.sts_syms
43+
44+
function Base.:(==)(a1::FunctionalAffect, a2::FunctionalAffect)
45+
isequal(a1.f, a2.f) && isequal(a1.sts, a2.sts) && isequal(a1.pars, a2.pars) &&
46+
isequal(a1.sts_syms, a2.sts_syms) && isequal(a1.pars_syms, a2.pars_syms) &&
47+
isequal(a1.ctx, a2.ctx)
48+
end
49+
50+
function Base.hash(a::FunctionalAffect, s::UInt)
51+
s = hash(a.f, s)
52+
s = hash(a.sts, s)
53+
s = hash(a.sts_syms, s)
54+
s = hash(a.pars, s)
55+
s = hash(a.pars_syms, s)
56+
hash(a.ctx, s)
57+
end
58+
59+
has_functional_affect(cb) = affects(cb) isa FunctionalAffect
60+
61+
namespace_affect(affect, s) = namespace_equation(affect, s)
62+
function namespace_affect(affect::FunctionalAffect, s)
63+
FunctionalAffect(func(affect),
64+
renamespace.((s,), states(affect)),
65+
states_syms(affect),
66+
renamespace.((s,), parameters(affect)),
67+
parameters_syms(affect),
68+
context(affect))
69+
end
70+
1271
#################################### continuous events #####################################
1372

1473
const NULL_AFFECT = Equation[]
1574
struct SymbolicContinuousCallback
1675
eqs::Vector{Equation}
17-
affect::Vector{Equation}
76+
affect
1877
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
1978
new(eqs, affect)
2079
end # Default affect to nothing
2180
end
2281

82+
SymbolicContinuousCallback(eqs::Vector{Equation}, affect::Function) = SymbolicContinuousCallback(eqs, SymbolicContinuousCallback(affect))
83+
2384
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
2485
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
2586
end
@@ -57,22 +118,26 @@ equations(cb::SymbolicContinuousCallback) = cb.eqs
57118
function equations(cbs::Vector{<:SymbolicContinuousCallback})
58119
reduce(vcat, [equations(cb) for cb in cbs])
59120
end
60-
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
61-
function affect_equations(cbs::Vector{SymbolicContinuousCallback})
62-
reduce(vcat, [affect_equations(cb) for cb in cbs])
121+
affects(cb::SymbolicContinuousCallback) = cb.affect
122+
function affects(cbs::Vector{SymbolicContinuousCallback})
123+
reduce(vcat, [affects(cb) for cb in cbs])
63124
end
64-
function namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
125+
126+
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
65127
SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)),
66-
namespace_equation.(affect_equations(cb), (s,)))
128+
namespace_affect.(affects(cb), (s,)))
67129
end
68130

131+
cb_add_context(cb::SymbolicContinuousCallback, s) = SymbolicContinuousCallback(equations(cb), af_add_context(affects(cb), s))
132+
69133
function continuous_events(sys::AbstractSystem)
70134
obs = get_continuous_events(sys)
71135
filter(!isempty, obs)
136+
72137
systems = get_systems(sys)
73138
cbs = [obs;
74139
reduce(vcat,
75-
(map(o -> namespace_equation(o, s), continuous_events(s))
140+
(map(o -> namespace_callback(o, s), continuous_events(s))
76141
for s in systems),
77142
init = SymbolicContinuousCallback[])]
78143
filter(!isempty, cbs)
@@ -81,23 +146,47 @@ end
81146
#################################### continuous events #####################################
82147

83148
struct SymbolicDiscreteCallback
149+
# condition can be one of:
150+
# TODO: Iterative
151+
# Δt::Real - Periodic with period Δt
152+
# Δts::Vector{Real} - events trigger in this times (Preset)
153+
# condition::Vector{Equation} - event triggered when condition is true
84154
condition
85-
affects::Vector{Equation}
155+
affects
156+
86157
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT)
87-
c = value(scalarize(condition))
88-
a = scalarize(affects)
89-
new(c, a)
158+
c = scalarize_condition(condition)
159+
a = scalarize_affects(affects)
160+
new(c,a)
90161
end # Default affect to nothing
91162
end
92163

164+
is_timed_condition(cb) = false
165+
is_timed_condition(::R) where {R<:Real} = true
166+
is_timed_condition(::V) where {V<:AbstractVector} = eltype(V) <: Real
167+
is_timed_condition(cb::SymbolicDiscreteCallback) = is_timed_condition(condition(cb))
168+
169+
scalarize_condition(condition) = is_timed_condition(condition) ? condition : value(scalarize(condition))
170+
namespace_condition(condition, s) = is_timed_condition(condition) ? condition : namespace_expr(condition, s)
171+
172+
scalarize_affects(affects) = scalarize(affects)
173+
scalarize_affects(affects::Tuple) = FunctionalAffect(affects...)
174+
scalarize_affects(affects::NamedTuple) = FunctionalAffect(;affects...)
175+
scalarize_affects(affects::FunctionalAffect) = affects
176+
93177
SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2])
94178
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough
95179

96180
function Base.show(io::IO, db::SymbolicDiscreteCallback)
97181
println(io, "condition: ", db.condition)
98182
println(io, "affects:")
99-
for affect in db.affects
100-
println(io, " ", affect)
183+
if db.affects isa FunctionalAffect
184+
# TODO
185+
println(io, " ", db.affects)
186+
else
187+
for affect in db.affects
188+
println(io, " ", affect)
189+
end
101190
end
102191
end
103192

@@ -106,21 +195,24 @@ function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback)
106195
end
107196
function Base.hash(cb::SymbolicDiscreteCallback, s::UInt)
108197
s = foldr(hash, cb.condition, init = s)
109-
foldr(hash, cb.affects, init = s)
198+
cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s)
110199
end
111200

112201
condition(cb::SymbolicDiscreteCallback) = cb.condition
113202
function conditions(cbs::Vector{<:SymbolicDiscreteCallback})
114203
reduce(vcat, condition(cb) for cb in cbs)
115204
end
116205

117-
affect_equations(cb::SymbolicDiscreteCallback) = cb.affects
118-
function affect_equations(cbs::Vector{SymbolicDiscreteCallback})
119-
reduce(vcat, affect_equations(cb) for cb in cbs)
206+
affects(cb::SymbolicDiscreteCallback) = cb.affects
207+
208+
function affects(cbs::Vector{SymbolicDiscreteCallback})
209+
reduce(vcat, affects(cb) for cb in cbs)
120210
end
121-
function namespace_equation(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
122-
SymbolicDiscreteCallback(namespace_expr(condition(cb), s),
123-
namespace_equation.(affect_equations(cb), Ref(s)))
211+
212+
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
213+
af = affects(cb)
214+
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
215+
SymbolicDiscreteCallback(namespace_condition(condition(cb), s), af)
124216
end
125217

126218
SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)]
@@ -134,7 +226,7 @@ function discrete_events(sys::AbstractSystem)
134226
systems = get_systems(sys)
135227
cbs = [obs;
136228
reduce(vcat,
137-
(map(o -> namespace_equation(o, s), discrete_events(s)) for s in systems),
229+
(map(o -> namespace_callback(o, s), discrete_events(s)) for s in systems),
138230
init = SymbolicDiscreteCallback[])]
139231
cbs
140232
end
@@ -178,7 +270,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
178270
end
179271

180272
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
181-
compile_affect(affect_equations(cb), args...; kwargs...)
273+
compile_affect(affects(cb), args...; kwargs...)
182274
end
183275

184276
"""
@@ -208,7 +300,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
208300
outvar = :u
209301
if outputidxs === nothing
210302
lhss = map(x -> x.lhs, eqs)
211-
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
303+
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're changing
212304
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
213305
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
214306
alleq = all(isequal(isparameter(first(update_vars))),
@@ -303,17 +395,70 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states
303395
end
304396
end
305397

398+
function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
399+
ind(sym, v) = findfirst(isequal(sym), v)
400+
inds(syms, v) = map(sym -> ind(sym, v), syms)
401+
v_inds = inds(states(affect), dvs)
402+
p_inds = inds(parameters(affect), ps)
403+
404+
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
405+
# (MTK should keep these symbols)
406+
v = filter(x -> !isnothing(x[1]), collect(zip(v_inds, states_syms(affect))))
407+
v_inds = [x[1] for x in v]
408+
v_syms = Tuple([x[2] for x in v])
409+
p = filter(x -> !isnothing(x[1]), collect(zip(p_inds, parameters_syms(affect))))
410+
p_inds = [x[1] for x in p]
411+
p_syms = Tuple([x[2] for x in p])
412+
413+
let v_inds=v_inds, p_inds=p_inds, v_syms=v_syms, p_syms=p_syms, user_affect=func(affect), ctx = context(affect)
414+
function (integ)
415+
uv = @views integ.u[v_inds]
416+
pv = @views integ.p[p_inds]
417+
418+
u = LArray{v_syms}(uv)
419+
p = LArray{p_syms}(pv)
420+
421+
user_affect(integ.t, u, p, ctx)
422+
end
423+
end
424+
end
425+
426+
function compile_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
427+
compile_user_affect(affect, sys, dvs, ps; kwargs...)
428+
end
429+
430+
function generate_timed_callback(cb, sys, dvs, ps; kwargs...)
431+
cond = condition(cb)
432+
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
433+
kwargs...)
434+
if cond isa AbstractVector
435+
# Preset Time
436+
return PresetTimeCallback(cond, as)
437+
else
438+
# Periodic
439+
return PeriodicCallback(as, cond)
440+
end
441+
end
442+
443+
function generate_discrete_callback(cb, sys, dvs, ps; kwargs...)
444+
if is_timed_condition(cb)
445+
return generate_timed_callback(cb, sys, dvs, ps, kwargs...)
446+
else
447+
c = compile_condition(cb, sys, dvs, ps; expression=Val{false}, kwargs...)
448+
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
449+
kwargs...)
450+
return DiscreteCallback(c, as)
451+
end
452+
end
453+
306454
function generate_discrete_callbacks(sys::AbstractSystem, dvs = states(sys),
307455
ps = parameters(sys); kwargs...)
308456
has_discrete_events(sys) || return nothing
309457
symcbs = discrete_events(sys)
310458
isempty(symcbs) && return nothing
311459

312460
dbs = map(symcbs) do cb
313-
c = compile_condition(cb, sys, dvs, ps; expression=Val{false}, kwargs...)
314-
as = compile_affect(affect_equations(cb), sys, dvs, ps; expression = Val{false},
315-
kwargs...)
316-
DiscreteCallback(c, as)
461+
generate_discrete_callback(cb, sys, dvs, ps; kwargs...)
317462
end
318463

319464
dbs
@@ -333,7 +478,7 @@ function process_events(sys; callback = nothing, has_difference = false, kwargs.
333478
if has_discrete_events(sys)
334479
discrete_cb = generate_discrete_callbacks(sys; kwargs...)
335480
else
336-
discrete_cb = nothing
481+
discrete_cb = []
337482
end
338483
difference_cb = has_difference ? generate_difference_cb(sys; kwargs...) : nothing
339484

0 commit comments

Comments
 (0)