Skip to content

Commit 2395a93

Browse files
authored
Merge pull request #1661 from isaacsas/discrete-callbacks
Cleanup callbacks
2 parents ba38a7a + f094b87 commit 2395a93

File tree

6 files changed

+228
-203
lines changed

6 files changed

+228
-203
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ include("domains.jl")
121121

122122
include("systems/abstractsystem.jl")
123123
include("systems/connectors.jl")
124+
include("systems/callbacks.jl")
124125

125126
include("systems/diffeqs/odesystem.jl")
126127
include("systems/diffeqs/sdesystem.jl")

src/systems/abstractsystem.jl

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -162,61 +162,6 @@ independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
162162
independent_variables(sys::AbstractTimeIndependentSystem) = []
163163
independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)
164164

165-
const NULL_AFFECT = Equation[]
166-
struct SymbolicContinuousCallback
167-
eqs::Vector{Equation}
168-
affect::Vector{Equation}
169-
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
170-
new(eqs, affect)
171-
end # Default affect to nothing
172-
end
173-
174-
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
175-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
176-
end
177-
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
178-
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
179-
s = foldr(hash, cb.eqs, init = s)
180-
foldr(hash, cb.affect, init = s)
181-
end
182-
183-
to_equation_vector(eq::Equation) = [eq]
184-
to_equation_vector(eqs::Vector{Equation}) = eqs
185-
function to_equation_vector(eqs::Vector{Any})
186-
isempty(eqs) || error("This should never happen")
187-
Equation[]
188-
end
189-
190-
function SymbolicContinuousCallback(args...)
191-
SymbolicContinuousCallback(to_equation_vector.(args)...)
192-
end # wrap eq in vector
193-
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
194-
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
195-
196-
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
197-
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
198-
SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs)
199-
function SymbolicContinuousCallbacks(ve::Vector{Equation})
200-
SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve))
201-
end
202-
function SymbolicContinuousCallbacks(others)
203-
SymbolicContinuousCallbacks(SymbolicContinuousCallback(others))
204-
end
205-
SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[])
206-
207-
equations(cb::SymbolicContinuousCallback) = cb.eqs
208-
function equations(cbs::Vector{<:SymbolicContinuousCallback})
209-
reduce(vcat, [equations(cb) for cb in cbs])
210-
end
211-
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
212-
function affect_equations(cbs::Vector{SymbolicContinuousCallback})
213-
reduce(vcat, [affect_equations(cb) for cb in cbs])
214-
end
215-
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb),
216-
(s,)),
217-
namespace_equation.(affect_equations(cb),
218-
(s,)))
219-
220165
for prop in [:eqs
221166
:noiseeqs
222167
:iv
@@ -507,18 +452,6 @@ function observed(sys::AbstractSystem)
507452
init = Equation[])]
508453
end
509454

510-
function continuous_events(sys::AbstractSystem)
511-
obs = get_continuous_events(sys)
512-
filter(!isempty, obs)
513-
systems = get_systems(sys)
514-
cbs = [obs;
515-
reduce(vcat,
516-
(map(o -> namespace_equation(o, s), continuous_events(s))
517-
for s in systems),
518-
init = SymbolicContinuousCallback[])]
519-
filter(!isempty, cbs)
520-
end
521-
522455
Base.@deprecate default_u0(x) defaults(x) false
523456
Base.@deprecate default_p(x) defaults(x) false
524457
function defaults(sys::AbstractSystem)
@@ -586,6 +519,20 @@ function isaffine(sys::AbstractSystem)
586519
all(isaffine(r, states(sys)) for r in rhs)
587520
end
588521

522+
function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
523+
# if something is not x(t) (the current state)
524+
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
525+
# than pass in a value in place of x(t).
526+
#
527+
# This is done by just making `x` the argument of the function.
528+
if istree(x) &&
529+
operation(x) isa Sym &&
530+
!(length(arguments(x)) == 1 && isequal(arguments(x)[1], get_iv(sys)))
531+
return operation(x)
532+
end
533+
return x
534+
end
535+
589536
struct AbstractSysToExpr
590537
sys::AbstractSystem
591538
states::Vector

src/systems/callbacks.jl

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
#################################### system operations #####################################
2+
get_continuous_events(sys::AbstractSystem) = Equation[]
3+
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
4+
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
5+
6+
#################################### continuous events #####################################
7+
8+
const NULL_AFFECT = Equation[]
9+
struct SymbolicContinuousCallback
10+
eqs::Vector{Equation}
11+
affect::Vector{Equation}
12+
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
13+
new(eqs, affect)
14+
end # Default affect to nothing
15+
end
16+
17+
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
18+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
19+
end
20+
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
21+
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
22+
s = foldr(hash, cb.eqs, init = s)
23+
foldr(hash, cb.affect, init = s)
24+
end
25+
26+
to_equation_vector(eq::Equation) = [eq]
27+
to_equation_vector(eqs::Vector{Equation}) = eqs
28+
function to_equation_vector(eqs::Vector{Any})
29+
isempty(eqs) || error("This should never happen")
30+
Equation[]
31+
end
32+
33+
function SymbolicContinuousCallback(args...)
34+
SymbolicContinuousCallback(to_equation_vector.(args)...)
35+
end # wrap eq in vector
36+
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
37+
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
38+
39+
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
40+
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
41+
SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs)
42+
function SymbolicContinuousCallbacks(ve::Vector{Equation})
43+
SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve))
44+
end
45+
function SymbolicContinuousCallbacks(others)
46+
SymbolicContinuousCallbacks(SymbolicContinuousCallback(others))
47+
end
48+
SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[])
49+
50+
equations(cb::SymbolicContinuousCallback) = cb.eqs
51+
function equations(cbs::Vector{<:SymbolicContinuousCallback})
52+
reduce(vcat, [equations(cb) for cb in cbs])
53+
end
54+
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
55+
function affect_equations(cbs::Vector{SymbolicContinuousCallback})
56+
reduce(vcat, [affect_equations(cb) for cb in cbs])
57+
end
58+
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb),
59+
(s,)),
60+
namespace_equation.(affect_equations(cb),
61+
(s,)))
62+
63+
function continuous_events(sys::AbstractSystem)
64+
obs = get_continuous_events(sys)
65+
filter(!isempty, obs)
66+
systems = get_systems(sys)
67+
cbs = [obs;
68+
reduce(vcat,
69+
(map(o -> namespace_equation(o, s), continuous_events(s))
70+
for s in systems),
71+
init = SymbolicContinuousCallback[])]
72+
filter(!isempty, cbs)
73+
end
74+
75+
################################# compilation functions ####################################
76+
77+
# handles ensuring that affect! functions work with integrator arguments
78+
function add_integrator_header()
79+
integrator = gensym(:MTKIntegrator)
80+
81+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
82+
expr.body),
83+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :u, :p, :t])], [],
84+
expr.body)
85+
end
86+
87+
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
88+
compile_affect(affect_equations(cb), args...; kwargs...)
89+
end
90+
91+
"""
92+
compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression, outputidxs, kwargs...)
93+
compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
94+
95+
Returns a function that takes an integrator as argument and modifies the state with the
96+
affect. The generated function has the signature `affect!(integrator)`.
97+
98+
Notes
99+
- `expression = Val{true}`, causes the generated function to be returned as an expression.
100+
If set to `Val{false}` a `RuntimeGeneratedFunction` will be returned.
101+
- `outputidxs`, a vector of indices of the output variables which should correspond to
102+
`states(sys)`. If provided, checks that the LHS of affect equations are variables are
103+
dropped, i.e. it is assumed these indices are correct and affect equations are
104+
well-formed.
105+
- `kwargs` are passed through to `Symbolics.build_function`.
106+
"""
107+
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
108+
expression = Val{true}, checkvars = true, kwargs...)
109+
if isempty(eqs)
110+
if expression == Val{true}
111+
return :((args...) -> ())
112+
else
113+
return (args...) -> () # We don't do anything in the callback, we're just after the event
114+
end
115+
else
116+
rhss = map(x -> x.rhs, eqs)
117+
118+
if outputidxs === nothing
119+
lhss = map(x -> x.lhs, eqs)
120+
all(isvariable, lhss) ||
121+
error("Non-variable symbolic expression found on the left hand side of an affect equation. Such equations must be of the form variable ~ symbolic expression for the new value of the variable.")
122+
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
123+
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
124+
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
125+
stateind(sym) = findfirst(isequal(sym), dvs)
126+
update_inds = stateind.(update_vars)
127+
else
128+
update_inds = outputidxs
129+
end
130+
131+
if checkvars
132+
u = map(x -> time_varying_as_func(value(x), sys), dvs)
133+
p = map(x -> time_varying_as_func(value(x), sys), ps)
134+
else
135+
u = dvs
136+
p = ps
137+
end
138+
t = get_iv(sys)
139+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression,
140+
wrap_code = add_integrator_header(),
141+
outputidxs = update_inds,
142+
kwargs...)
143+
rf_ip
144+
end
145+
end
146+
147+
function generate_rootfinding_callback(sys::AbstractODESystem, dvs = states(sys),
148+
ps = parameters(sys); kwargs...)
149+
cbs = continuous_events(sys)
150+
isempty(cbs) && return nothing
151+
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
152+
end
153+
154+
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states(sys),
155+
ps = parameters(sys); kwargs...)
156+
eqs = map(cb -> cb.eqs, cbs)
157+
num_eqs = length.(eqs)
158+
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
159+
# fuse equations to create VectorContinuousCallback
160+
eqs = reduce(vcat, eqs)
161+
# rewrite all equations as 0 ~ interesting stuff
162+
eqs = map(eqs) do eq
163+
isequal(eq.lhs, 0) && return eq
164+
0 ~ eq.lhs - eq.rhs
165+
end
166+
167+
rhss = map(x -> x.rhs, eqs)
168+
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
169+
170+
u = map(x -> time_varying_as_func(value(x), sys), dvs)
171+
p = map(x -> time_varying_as_func(value(x), sys), ps)
172+
t = get_iv(sys)
173+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false}, kwargs...)
174+
175+
affect_functions = map(cbs) do cb # Keep affect function separate
176+
eq_aff = affect_equations(cb)
177+
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
178+
end
179+
180+
if length(eqs) == 1
181+
cond = function (u, t, integ)
182+
if DiffEqBase.isinplace(integ.sol.prob)
183+
tmp, = DiffEqBase.get_tmp_cache(integ)
184+
rf_ip(tmp, u, integ.p, t)
185+
tmp[1]
186+
else
187+
rf_oop(u, integ.p, t)
188+
end
189+
end
190+
ContinuousCallback(cond, affect_functions[])
191+
else
192+
cond = function (out, u, t, integ)
193+
rf_ip(out, u, integ.p, t)
194+
end
195+
196+
# since there may be different number of conditions and affects,
197+
# we build a map that translates the condition eq. number to the affect number
198+
eq_ind2affect = reduce(vcat,
199+
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
200+
@assert length(eq_ind2affect) == length(eqs)
201+
@assert maximum(eq_ind2affect) == length(affect_functions)
202+
203+
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
204+
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
205+
affect_functions[eq_ind2affect[eq_ind]](integ)
206+
end
207+
end
208+
VectorContinuousCallback(cond, affect, length(eqs))
209+
end
210+
end

0 commit comments

Comments
 (0)