Skip to content

Commit 0164062

Browse files
committed
Support more of the SciMLBase events API
1 parent c2e6e4a commit 0164062

File tree

2 files changed

+375
-41
lines changed

2 files changed

+375
-41
lines changed

src/systems/callbacks.jl

Lines changed: 147 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -76,24 +76,59 @@ end
7676
#################################### continuous events #####################################
7777

7878
const NULL_AFFECT = Equation[]
79+
"""
80+
SymbolicContinuousCallback(eqs::Vector{Equation}, affect, affect_neg, rootfind)
81+
82+
A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq`
83+
as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied.
84+
By default `affect_neg = affect`; to only get rising edges specify `affect_neg = nothing`.
85+
86+
Assume without loss of generality that the equation is of the form `c(u,p,t) ~ 0`; we denote the integrator state as `i.u`.
87+
For simplicty, we define `prev_sign = sign(c(u[t-1], p[t-1], t-1))` and `cur_sign = sign(c(u[t], p[t], t))`.
88+
A condition edge will be detected and the callback will be invoked iff `prev_sign * cur_sign <= 0`.
89+
Inter-sample condition activation is not guaranteed; for example if we use the dirac delta function as `c` to insert a
90+
sharp discontinuity between integrator steps (which in this example would not normally be identified by adaptivity) then the condition is not
91+
gauranteed to be triggered.
92+
93+
Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used
94+
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). Multiple callbacks in the same system with different `rootfind` operations will be resolved
95+
into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`, which may cause some callbacks to not fire if several become
96+
active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules.
97+
98+
The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be
99+
triggered iff an edge is detected `prev_sign > 0`.
100+
101+
Affects (i.e. `affect` and `affect_neg`) can be specified as either:
102+
* A list of equations that should be applied when the callback is triggered (e.g. `x ~ 3, y ~ 7`) which must be of the form `unknown ~ observed value` where each `unknown` appears only once. Equations will be applied in the order that they appear in the vector; parameters and state updates will become immediately visible to following equations.
103+
* A tuple `(f!, unknowns, read_parameters, modified_parameters, ctx)`, where:
104+
+ `f!` is a function with signature `(integ, u, p, ctx)` that is called with the integrator, a state *index* vector `u` derived from `unknowns`, a parameter *index* vector `p` derived from `read_parameters`, and the `ctx` that was given at construction time. Note that `ctx` is aliased between instances.
105+
+ `unknowns` is a vector of symbolic unknown variables and optionally their aliases (e.g. if the model was defined with `@variables x(t)` then a valid value for `unknowns` would be `[x]`). A variable can be aliased with a pair `x => :y`. The indices of these `unknowns` will be passed to `f!` in `u` in a named tuple; in the earlier example, if we pass `[x]` as `unknowns` then `f!` can access `x` as `integ.u[u.x]`. If no alias is specified the name of the index will be the symbol version of the variable name.
106+
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107+
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108+
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109+
"""
79110
struct SymbolicContinuousCallback
80111
eqs::Vector{Equation}
81112
affect::Union{Vector{Equation}, FunctionalAffect}
82-
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
83-
new(eqs, make_affect(affect))
113+
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
114+
rootfind::SciMLBase.RootfindOpt
115+
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, rootfind=SciMLBase.LeftRootFind)
116+
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
84117
end # Default affect to nothing
85118
end
86119
make_affect(affect) = affect
87120
make_affect(affect::Tuple) = FunctionalAffect(affect...)
88121
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
89122

90123
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
91-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
124+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
92125
end
93126
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
94127
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
95128
s = foldr(hash, cb.eqs, init = s)
96-
cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
129+
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
130+
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) : hash(cb.affect_neg, s)
131+
hash(cb.rootfind, s)
97132
end
98133

99134
to_equation_vector(eq::Equation) = [eq]
@@ -108,6 +143,8 @@ function SymbolicContinuousCallback(args...)
108143
end # wrap eq in vector
109144
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
110145
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
146+
SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback([eqs], affect, affect_neg, rootfind)
147+
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback(eqs, affect, affect_neg, rootfind)
111148

112149
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
113150
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
@@ -130,12 +167,20 @@ function affects(cbs::Vector{SymbolicContinuousCallback})
130167
mapreduce(affects, vcat, cbs, init = Equation[])
131168
end
132169

170+
affect_negs(cb::SymbolicContinuousCallback) = cb.affect_neg
171+
function affect_negs(cbs::Vector{SymbolicContinuousCallback})
172+
mapreduce(affect_negs, vcat, cbs, init = Equation[])
173+
end
174+
133175
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
134176
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
177+
namespace_affects(::Nothing, s) = nothing
135178

136179
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
137-
SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)),
138-
namespace_affects(affects(cb), s))
180+
SymbolicContinuousCallback(
181+
namespace_equation.(equations(cb), (s,)),
182+
namespace_affects(affects(cb), s),
183+
namespace_affects(affect_negs(cb), s))
139184
end
140185

141186
"""
@@ -159,7 +204,7 @@ function continuous_events(sys::AbstractSystem)
159204
filter(!isempty, cbs)
160205
end
161206

162-
#################################### continuous events #####################################
207+
#################################### discrete events #####################################
163208

164209
struct SymbolicDiscreteCallback
165210
# condition can be one of:
@@ -461,12 +506,34 @@ function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sy
461506
isempty(cbs) && return nothing
462507
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
463508
end
509+
"""
510+
Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to
511+
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
512+
"""
513+
function generate_single_rootfinding_callback(eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
514+
ps = full_parameters(sys); kwargs...)
515+
if !isequal(eq.lhs, 0)
516+
eq = 0 ~ eq.lhs - eq.rhs
517+
end
518+
519+
rf_oop, rf_ip = generate_custom_function(sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...)
520+
affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs)
521+
cond = function (u, t, integ)
522+
if DiffEqBase.isinplace(integ.sol.prob)
523+
tmp, = DiffEqBase.get_tmp_cache(integ)
524+
rf_ip(tmp, u, parameter_values(integ), t)
525+
tmp[1]
526+
else
527+
rf_oop(u, parameter_values(integ), t)
528+
end
529+
end
530+
return ContinuousCallback(cond, affect_function.affect, affect_function.affect_neg, rootfind=cb.rootfind)
531+
end
464532

465-
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
466-
ps = full_parameters(sys); kwargs...)
533+
function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
534+
ps = full_parameters(sys); rootfind=SciMLBase.RightRootFind, kwargs...)
467535
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
468536
num_eqs = length.(eqs)
469-
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
470537
# fuse equations to create VectorContinuousCallback
471538
eqs = reduce(vcat, eqs)
472539
# rewrite all equations as 0 ~ interesting stuff
@@ -476,45 +543,85 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
476543
end
477544

478545
rhss = map(x -> x.rhs, eqs)
479-
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
480-
481-
rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
482-
kwargs...)
546+
_, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
483547

484-
affect_functions = map(cbs) do cb # Keep affect function separate
485-
eq_aff = affects(cb)
486-
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
548+
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs]
549+
cond = function (out, u, t, integ)
550+
rf_ip(out, u, parameter_values(integ), t)
487551
end
488552

489-
if length(eqs) == 1
490-
cond = function (u, t, integ)
491-
if DiffEqBase.isinplace(integ.sol.prob)
492-
tmp, = DiffEqBase.get_tmp_cache(integ)
493-
rf_ip(tmp, u, parameter_values(integ), t)
494-
tmp[1]
495-
else
496-
rf_oop(u, parameter_values(integ), t)
553+
# since there may be different number of conditions and affects,
554+
# we build a map that translates the condition eq. number to the affect number
555+
eq_ind2affect = reduce(vcat,
556+
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
557+
@assert length(eq_ind2affect) == length(eqs)
558+
@assert maximum(eq_ind2affect) == length(affect_functions)
559+
560+
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
561+
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
562+
affect_functions[eq_ind2affect[eq_ind]].affect(integ)
563+
end
564+
end
565+
affect_neg = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
566+
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
567+
affect_neg = affect_functions[eq_ind2affect[eq_ind]].affect_neg
568+
if isnothing(affect_neg)
569+
return # skip if the neg function doesn't exist - don't want to split this into a separate VCC because that'd break ordering
497570
end
571+
affect_neg(integ)
498572
end
499-
ContinuousCallback(cond, affect_functions[])
573+
end
574+
return VectorContinuousCallback(cond, affect, affect_neg, length(eqs), rootfind=rootfind)
575+
end
576+
577+
"""
578+
Compile a single continous callback affect function(s).
579+
"""
580+
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
581+
eq_aff = affects(cb)
582+
eq_neg_aff = affect_negs(cb)
583+
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
584+
if eq_neg_aff === eq_aff
585+
affect_neg = affect
586+
elseif isnothing(eq_neg_aff)
587+
affect_neg = nothing
500588
else
501-
cond = function (out, u, t, integ)
502-
rf_ip(out, u, parameter_values(integ), t)
589+
affect_neg = compile_affect(eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
590+
end
591+
(affect=affect, affect_neg=affect_neg)
592+
end
593+
594+
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
595+
ps = full_parameters(sys); kwargs...)
596+
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
597+
num_eqs = length.(eqs)
598+
total_eqs = sum(num_eqs)
599+
(isempty(eqs) || total_eqs == 0) && return nothing
600+
if total_eqs == 1
601+
# find the callback with only one eq
602+
cb_ind = findfirst(>(0), num_eqs)
603+
if isnothing(cb_ind)
604+
error("Inconsistent state in affect compilation; one equation but no callback with equations?")
503605
end
606+
cb = cbs[cb_ind]
607+
return generate_single_rootfinding_callback(cb.eqs[], cb, sys, dvs, ps; kwargs...)
608+
end
504609

505-
# since there may be different number of conditions and affects,
506-
# we build a map that translates the condition eq. number to the affect number
507-
eq_ind2affect = reduce(vcat,
508-
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
509-
@assert length(eq_ind2affect) == length(eqs)
510-
@assert maximum(eq_ind2affect) == length(affect_functions)
610+
# group the cbs by what rootfind op they use
611+
# groupby would be very useful here, but alas
612+
cb_classes = Dict{@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
613+
for cb in cbs
614+
push!(get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind=cb.rootfind, )), cb)
615+
end
511616

512-
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
513-
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
514-
affect_functions[eq_ind2affect[eq_ind]](integ)
515-
end
516-
end
517-
VectorContinuousCallback(cond, affect, length(eqs))
617+
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
618+
compiled_callbacks = map(collect(pairs(sort!(OrderedDict(cb_classes); by=p->p.rootfind)))) do (equiv_class, cbs_in_class)
619+
return generate_vector_rootfinding_callback(cbs_in_class, sys, dvs, ps; rootfind=equiv_class.rootfind, kwargs...)
620+
end
621+
if length(compiled_callbacks) == 1
622+
return compiled_callbacks[]
623+
else
624+
return CallbackSet(compiled_callbacks...)
518625
end
519626
end
520627

@@ -528,7 +635,6 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
528635
ps_ind = Dict(reverse(en) for en in enumerate(ps))
529636
p_inds = map(sym -> ps_ind[sym], parameters(affect))
530637
end
531-
532638
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
533639
# (MTK should keep these symbols)
534640
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>

0 commit comments

Comments
 (0)