Skip to content

Commit 41a57ab

Browse files
committed
rearrange callbacks
1 parent 9cb6a4d commit 41a57ab

File tree

6 files changed

+210
-195
lines changed

6 files changed

+210
-195
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ include("domains.jl")
121121

122122
include("systems/abstractsystem.jl")
123123
include("systems/connectors.jl")
124+
include("systems/callbacks.jl")
124125

125126
include("systems/diffeqs/odesystem.jl")
126127
include("systems/diffeqs/sdesystem.jl")

src/systems/abstractsystem.jl

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -162,61 +162,6 @@ independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
162162
independent_variables(sys::AbstractTimeIndependentSystem) = []
163163
independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)
164164

165-
const NULL_AFFECT = Equation[]
166-
struct SymbolicContinuousCallback
167-
eqs::Vector{Equation}
168-
affect::Vector{Equation}
169-
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
170-
new(eqs, affect)
171-
end # Default affect to nothing
172-
end
173-
174-
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
175-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
176-
end
177-
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
178-
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
179-
s = foldr(hash, cb.eqs, init = s)
180-
foldr(hash, cb.affect, init = s)
181-
end
182-
183-
to_equation_vector(eq::Equation) = [eq]
184-
to_equation_vector(eqs::Vector{Equation}) = eqs
185-
function to_equation_vector(eqs::Vector{Any})
186-
isempty(eqs) || error("This should never happen")
187-
Equation[]
188-
end
189-
190-
function SymbolicContinuousCallback(args...)
191-
SymbolicContinuousCallback(to_equation_vector.(args)...)
192-
end # wrap eq in vector
193-
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
194-
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
195-
196-
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
197-
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
198-
SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs)
199-
function SymbolicContinuousCallbacks(ve::Vector{Equation})
200-
SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve))
201-
end
202-
function SymbolicContinuousCallbacks(others)
203-
SymbolicContinuousCallbacks(SymbolicContinuousCallback(others))
204-
end
205-
SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[])
206-
207-
equations(cb::SymbolicContinuousCallback) = cb.eqs
208-
function equations(cbs::Vector{<:SymbolicContinuousCallback})
209-
reduce(vcat, [equations(cb) for cb in cbs])
210-
end
211-
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
212-
function affect_equations(cbs::Vector{SymbolicContinuousCallback})
213-
reduce(vcat, [affect_equations(cb) for cb in cbs])
214-
end
215-
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb),
216-
(s,)),
217-
namespace_equation.(affect_equations(cb),
218-
(s,)))
219-
220165
for prop in [:eqs
221166
:noiseeqs
222167
:iv
@@ -507,18 +452,6 @@ function observed(sys::AbstractSystem)
507452
init = Equation[])]
508453
end
509454

510-
function continuous_events(sys::AbstractSystem)
511-
obs = get_continuous_events(sys)
512-
filter(!isempty, obs)
513-
systems = get_systems(sys)
514-
cbs = [obs;
515-
reduce(vcat,
516-
(map(o -> namespace_equation(o, s), continuous_events(s))
517-
for s in systems),
518-
init = SymbolicContinuousCallback[])]
519-
filter(!isempty, cbs)
520-
end
521-
522455
Base.@deprecate default_u0(x) defaults(x) false
523456
Base.@deprecate default_p(x) defaults(x) false
524457
function defaults(sys::AbstractSystem)
@@ -586,6 +519,20 @@ function isaffine(sys::AbstractSystem)
586519
all(isaffine(r, states(sys)) for r in rhs)
587520
end
588521

522+
function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
523+
# if something is not x(t) (the current state)
524+
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
525+
# than pass in a value in place of x(t).
526+
#
527+
# This is done by just making `x` the argument of the function.
528+
if istree(x) &&
529+
operation(x) isa Sym &&
530+
!(length(arguments(x)) == 1 && isequal(arguments(x)[1], get_iv(sys)))
531+
return operation(x)
532+
end
533+
return x
534+
end
535+
589536
struct AbstractSysToExpr
590537
sys::AbstractSystem
591538
states::Vector

src/systems/callbacks.jl

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#################################### system operations #####################################
2+
get_continuous_events(sys::AbstractSystem) = Equation[]
3+
get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events)
4+
has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events)
5+
6+
7+
#################################### continuous events #####################################
8+
9+
const NULL_AFFECT = Equation[]
10+
struct SymbolicContinuousCallback
11+
eqs::Vector{Equation}
12+
affect::Vector{Equation}
13+
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
14+
new(eqs, affect)
15+
end # Default affect to nothing
16+
end
17+
18+
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
19+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
20+
end
21+
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
22+
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
23+
s = foldr(hash, cb.eqs, init = s)
24+
foldr(hash, cb.affect, init = s)
25+
end
26+
27+
to_equation_vector(eq::Equation) = [eq]
28+
to_equation_vector(eqs::Vector{Equation}) = eqs
29+
function to_equation_vector(eqs::Vector{Any})
30+
isempty(eqs) || error("This should never happen")
31+
Equation[]
32+
end
33+
34+
function SymbolicContinuousCallback(args...)
35+
SymbolicContinuousCallback(to_equation_vector.(args)...)
36+
end # wrap eq in vector
37+
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
38+
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
39+
40+
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
41+
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
42+
SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs)
43+
function SymbolicContinuousCallbacks(ve::Vector{Equation})
44+
SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve))
45+
end
46+
function SymbolicContinuousCallbacks(others)
47+
SymbolicContinuousCallbacks(SymbolicContinuousCallback(others))
48+
end
49+
SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[])
50+
51+
equations(cb::SymbolicContinuousCallback) = cb.eqs
52+
function equations(cbs::Vector{<:SymbolicContinuousCallback})
53+
reduce(vcat, [equations(cb) for cb in cbs])
54+
end
55+
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
56+
function affect_equations(cbs::Vector{SymbolicContinuousCallback})
57+
reduce(vcat, [affect_equations(cb) for cb in cbs])
58+
end
59+
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb),
60+
(s,)),
61+
namespace_equation.(affect_equations(cb),
62+
(s,)))
63+
64+
function continuous_events(sys::AbstractSystem)
65+
obs = get_continuous_events(sys)
66+
filter(!isempty, obs)
67+
systems = get_systems(sys)
68+
cbs = [obs;
69+
reduce(vcat,
70+
(map(o -> namespace_equation(o, s), continuous_events(s))
71+
for s in systems),
72+
init = SymbolicContinuousCallback[])]
73+
filter(!isempty, cbs)
74+
end
75+
76+
77+
################################# compilation functions ####################################
78+
79+
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
80+
compile_affect(affect_equations(cb), args...; kwargs...)
81+
end
82+
83+
"""
84+
compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
85+
compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
86+
87+
Returns a function that takes an integrator as argument and modifies the state with the affect.
88+
"""
89+
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression = Val{false}, kwargs...)
90+
if isempty(eqs)
91+
return (args...) -> () # We don't do anything in the callback, we're just after the event
92+
else
93+
rhss = map(x -> x.rhs, eqs)
94+
lhss = map(x -> x.lhs, eqs)
95+
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
96+
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
97+
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
98+
vars = states(sys)
99+
100+
u = map(x -> time_varying_as_func(value(x), sys), vars)
101+
p = map(x -> time_varying_as_func(value(x), sys), ps)
102+
t = get_iv(sys)
103+
# stateind(sym) = findfirst(isequal(sym), vars)
104+
# update_inds = stateind.(update_vars)
105+
# rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, wrap_code=add_integrator_header(), outputidxs = update_inds, kwargs...)
106+
# rf_ip
107+
108+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, kwargs...)
109+
110+
stateind(sym) = findfirst(isequal(sym), vars)
111+
112+
update_inds = stateind.(update_vars)
113+
let update_inds = update_inds
114+
function (integ)
115+
lhs = @views integ.u[update_inds]
116+
rf_ip(lhs, integ.u, integ.p, integ.t)
117+
end
118+
end
119+
end
120+
end
121+
122+
123+
function generate_rootfinding_callback(sys::ODESystem, dvs = states(sys),
124+
ps = parameters(sys); kwargs...)
125+
cbs = continuous_events(sys)
126+
isempty(cbs) && return nothing
127+
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
128+
end
129+
130+
function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys),
131+
ps = parameters(sys); kwargs...)
132+
eqs = map(cb -> cb.eqs, cbs)
133+
num_eqs = length.(eqs)
134+
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
135+
# fuse equations to create VectorContinuousCallback
136+
eqs = reduce(vcat, eqs)
137+
# rewrite all equations as 0 ~ interesting stuff
138+
eqs = map(eqs) do eq
139+
isequal(eq.lhs, 0) && return eq
140+
0 ~ eq.lhs - eq.rhs
141+
end
142+
143+
rhss = map(x -> x.rhs, eqs)
144+
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
145+
146+
u = map(x -> time_varying_as_func(value(x), sys), dvs)
147+
p = map(x -> time_varying_as_func(value(x), sys), ps)
148+
t = get_iv(sys)
149+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false}, kwargs...)
150+
151+
affect_functions = map(cbs) do cb # Keep affect function separate
152+
eq_aff = affect_equations(cb)
153+
affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...)
154+
end
155+
156+
if length(eqs) == 1
157+
cond = function (u, t, integ)
158+
if DiffEqBase.isinplace(integ.sol.prob)
159+
tmp, = DiffEqBase.get_tmp_cache(integ)
160+
rf_ip(tmp, u, integ.p, t)
161+
tmp[1]
162+
else
163+
rf_oop(u, integ.p, t)
164+
end
165+
end
166+
ContinuousCallback(cond, affect_functions[])
167+
else
168+
cond = function (out, u, t, integ)
169+
rf_ip(out, u, integ.p, t)
170+
end
171+
172+
# since there may be different number of conditions and affects,
173+
# we build a map that translates the condition eq. number to the affect number
174+
eq_ind2affect = reduce(vcat,
175+
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
176+
@assert length(eq_ind2affect) == length(eqs)
177+
@assert maximum(eq_ind2affect) == length(affect_functions)
178+
179+
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
180+
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
181+
affect_functions[eq_ind2affect[eq_ind]](integ)
182+
end
183+
end
184+
VectorContinuousCallback(cond, affect, length(eqs))
185+
end
186+
end

0 commit comments

Comments
 (0)