Skip to content

Commit 52bdc81

Browse files
authored
Include root-finding equations in ODESystem (#1337)
1 parent 8bec990 commit 52bdc81

File tree

11 files changed

+603
-23
lines changed

11 files changed

+603
-23
lines changed

docs/src/basics/AbstractSystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Optionally, a system could have:
5656

5757
- `observed(sys)`: All observed equations of the system and its subsystems.
5858
- `get_observed(sys)`: Observed equations of the current-level system.
59+
- `get_continuous_events(sys)`: `SymbolicContinuousCallback`s of the current-level system.
5960
- `get_defaults(sys)`: A `Dict` that maps variables into their default values.
6061
- `independent_variables(sys)`: The independent variables of a system.
6162
- `get_noiseeqs(sys)`: Noise equations of the current-level system.

docs/src/basics/Composition.md

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,94 @@ strongly connected components calculated during the process of simplification
250250
as the basis for building pre-simplified nonlinear systems in the implicit
251251
solving. In summary: these problems are structurally modified, but could be
252252
more efficient and more stable.
253+
254+
## Components with discontinuous dynamics
255+
When modeling, e.g., impacts, saturations or Coulomb friction, the dynamic equations are discontinuous in either the state or one of its derivatives. This causes the solver to take very small steps around the discontinuity, and sometimes leads to early stopping due to `dt <= dt_min`. The correct way to handle such dynamics is to tell the solver about the discontinuity be means of a root-finding equation. [`ODEsystem`](@ref)s accept a keyword argument `continuous_events`
256+
```
257+
ODESystem(eqs, ...; continuous_events::Vector{Equation})
258+
ODESystem(eqs, ...; continuous_events::Pair{Vector{Equation}, Vector{Equation}})
259+
```
260+
where equations can be added that evaluate to 0 at discontinuities.
261+
262+
To model events that have an affect on the state, provide `events::Pair{Vector{Equation}, Vector{Equation}}` where the first entry in the pair is a vector of equations describing event conditions, and the second vector of equations describe the affect on the state. The affect equations must be on the form
263+
```
264+
single_state_variable ~ expression_involving_any_variables
265+
```
266+
267+
### Example: Friction
268+
The system below illustrates how this can be used to model Coulomb friction
269+
```julia
270+
using ModelingToolkit, OrdinaryDiffEq, Plots
271+
function UnitMassWithFriction(k; name)
272+
@variables t x(t)=0 v(t)=0
273+
D = Differential(t)
274+
eqs = [
275+
D(x) ~ v
276+
D(v) ~ sin(t) - k*sign(v) # f = ma, sinusoidal force acting on the mass, and Coulomb friction opposing the movement
277+
]
278+
ODESystem(eqs, t, continuous_events=[v ~ 0], name=name) # when v = 0 there is a discontinuity
279+
end
280+
@named m = UnitMassWithFriction(0.7)
281+
prob = ODEProblem(m, Pair[], (0, 10pi))
282+
sol = solve(prob, Tsit5())
283+
plot(sol)
284+
```
285+
286+
### Example: Bouncing ball
287+
In the documentation for DifferentialEquations, we have an example where a bouncing ball is simulated using callbacks which has an `affect!` on the state. We can model the same system using ModelingToolkit like this
288+
289+
```julia
290+
@variables t x(t)=1 v(t)=0
291+
D = Differential(t)
292+
293+
root_eqs = [x ~ 0] # the event happens at the ground x(t) = 0
294+
affect = [v ~ -v] # the effect is that the velocity changes sign
295+
296+
@named ball = ODESystem([
297+
D(x) ~ v
298+
D(v) ~ -9.8
299+
], t, continuous_events = root_eqs => affect) # equation => affect
300+
301+
ball = structural_simplify(ball)
302+
303+
tspan = (0.0,5.0)
304+
prob = ODEProblem(ball, Pair[], tspan)
305+
sol = solve(prob,Tsit5())
306+
@assert 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
307+
plot(sol)
308+
```
309+
310+
### Test bouncing ball in 2D with walls
311+
Multiple events? No problem! This example models a bouncing ball in 2D that is enclosed by two walls at $y = \pm 1.5$.
312+
```julia
313+
@variables t x(t)=1 y(t)=0 vx(t)=0 vy(t)=2
314+
D = Differential(t)
315+
316+
continuous_events = [ # This time we have a vector of pairs
317+
[x ~ 0] => [vx ~ -vx]
318+
[y ~ -1.5, y ~ 1.5] => [vy ~ -vy]
319+
]
320+
321+
@named ball = ODESystem([
322+
D(x) ~ vx,
323+
D(y) ~ vy,
324+
D(vx) ~ -9.8-0.1vx, # gravity + some small air resistance
325+
D(vy) ~ -0.1vy,
326+
], t, continuous_events = continuous_events)
327+
328+
329+
ball = structural_simplify(ball)
330+
331+
tspan = (0.0,10.0)
332+
prob = ODEProblem(ball, Pair[], tspan)
333+
334+
sol = solve(prob,Tsit5())
335+
@assert 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
336+
@assert minimum(sol[y]) > -1.5 # check wall conditions
337+
@assert maximum(sol[y]) < 1.5 # check wall conditions
338+
339+
tv = sort([LinRange(0, 10, 200); sol.t])
340+
plot(sol(tv)[y], sol(tv)[x], line_z=tv)
341+
vline!([-1.5, 1.5], l=(:black, 5), primary=false)
342+
hline!([0], l=(:black, 5), primary=false)
343+
```

src/structural_transformation/codegen.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,39 @@ function partitions_dag(s::SystemStructure)
125125
sparse(I, J, true, n, n)
126126
end
127127

128-
function gen_nlsolve(sys, eqs, vars; checkbounds=true)
129-
@assert !isempty(vars)
130-
@assert length(eqs) == length(vars)
128+
"""
129+
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)
130+
131+
Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
132+
as well as a call to the root-finding solver.
133+
134+
`exprs` is a two element vector
135+
```
136+
exprs = [fname = f, numerical_nlsolve(fname, ...)]
137+
```
138+
139+
# Arguments:
140+
- `eqs`: Equations to find roots of.
141+
- `vars`: ???
142+
- `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
143+
- `checkbounds`: Apply bounds checking in the generated code.
144+
"""
145+
function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
146+
isempty(vars) && throw(ArgumentError("vars may not be empty"))
147+
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
131148
rhss = map(x->x.rhs, eqs)
132149
# We use `vars` instead of `graph` to capture parameters, too.
133150
allvars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
134-
params = setdiff(allvars, vars)
151+
params = setdiff(allvars, vars) # these are not the subject of the root finding
135152

136-
u0map = defaults(sys)
137153
# splatting to tighten the type
138154
u0 = [map(var->get(u0map, var, 1e-3), vars)...]
139155
# specialize on the scalar case
140156
isscalar = length(u0) == 1
141157
u0 = isscalar ? u0[1] : SVector(u0...)
142158

143159
fname = gensym("fun")
160+
# f is the function to find roots on
144161
f = Func(
145162
[
146163
DestructuredArgs(vars, inbounds=!checkbounds)
@@ -150,6 +167,7 @@ function gen_nlsolve(sys, eqs, vars; checkbounds=true)
150167
isscalar ? rhss[1] : MakeArray(rhss, SVector)
151168
) |> SymbolicUtils.Code.toexpr
152169

170+
# solver call contains code to call the root-finding solver on the function f
153171
solver_call = LiteralExpr(quote
154172
$numerical_nlsolve(
155173
$fname,
@@ -174,8 +192,9 @@ function get_torn_eqs_vars(sys; checkbounds=true)
174192

175193
torn_eqs = map(idxs-> eqs[idxs], map(x->x.e_residual, partitions))
176194
torn_vars = map(idxs->vars[idxs], map(x->x.v_residual, partitions))
195+
u0map = defaults(sys)
177196

178-
gen_nlsolve.((sys,), torn_eqs, torn_vars, checkbounds=checkbounds)
197+
gen_nlsolve.(torn_eqs, torn_vars, (u0map,), checkbounds=checkbounds)
179198
end
180199

181200
function build_torn_function(
@@ -308,8 +327,8 @@ function build_observed_function(
308327

309328
torn_eqs = map(idxs-> eqs[idxs.e_residual], subset)
310329
torn_vars = map(idxs->fullvars[idxs.v_residual], subset)
311-
312-
solves = gen_nlsolve.((sys,), torn_eqs, torn_vars; checkbounds=checkbounds)
330+
u0map = defaults(sys)
331+
solves = gen_nlsolve.(torn_eqs, torn_vars, (u0map,); checkbounds=checkbounds)
313332
else
314333
solves = []
315334
end

src/systems/abstractsystem.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,40 @@ independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
154154
independent_variables(sys::AbstractTimeIndependentSystem) = []
155155
independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)
156156

157+
const NULL_AFFECT = Equation[]
158+
struct SymbolicContinuousCallback
159+
eqs::Vector{Equation}
160+
affect::Vector{Equation}
161+
SymbolicContinuousCallback(eqs::Vector{Equation}, affect=NULL_AFFECT) = new(eqs, affect) # Default affect to nothing
162+
end
163+
164+
Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) = isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
165+
166+
to_equation_vector(eq::Equation) = [eq]
167+
to_equation_vector(eqs::Vector{Equation}) = eqs
168+
function to_equation_vector(eqs::Vector{Any})
169+
isempty(eqs) || error("This should never happen")
170+
Equation[]
171+
end
172+
173+
SymbolicContinuousCallback(args...) = SymbolicContinuousCallback(to_equation_vector.(args)...) # wrap eq in vector
174+
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
175+
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
176+
177+
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
178+
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
179+
SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs)
180+
SymbolicContinuousCallbacks(ve::Vector{Equation}) = SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve))
181+
SymbolicContinuousCallbacks(others) = SymbolicContinuousCallbacks(SymbolicContinuousCallback(others))
182+
SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[])
183+
184+
equations(cb::SymbolicContinuousCallback) = cb.eqs
185+
equations(cbs::Vector{<:SymbolicContinuousCallback}) = reduce(vcat, [equations(cb) for cb in cbs])
186+
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
187+
affect_equations(cbs::Vector{SymbolicContinuousCallback}) = reduce(vcat, [affect_equations(cb) for cb in cbs])
188+
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb), (s, )), namespace_equation.(affect_equations(cb), (s, )))
189+
190+
157191
function structure(sys::AbstractSystem)
158192
s = get_structure(sys)
159193
s isa SystemStructure || throw(ArgumentError("SystemStructure is not yet initialized, please run `sys = initialize_system_structure(sys)` or `sys = alias_elimination(sys)`."))
@@ -415,6 +449,15 @@ function observed(sys::AbstractSystem)
415449
init=Equation[])]
416450
end
417451

452+
function continuous_events(sys::AbstractSystem)
453+
obs = get_continuous_events(sys)
454+
systems = get_systems(sys)
455+
[obs;
456+
reduce(vcat,
457+
(map(o->namespace_equation(o, s), continuous_events(s)) for s in systems),
458+
init=SymbolicContinuousCallback[])]
459+
end
460+
418461
Base.@deprecate default_u0(x) defaults(x) false
419462
Base.@deprecate default_p(x) defaults(x) false
420463
function defaults(sys::AbstractSystem)
@@ -941,6 +984,7 @@ function Base.hash(sys::AbstractSystem, s::UInt)
941984
s = foldr(hash, get_eqs(sys), init=s)
942985
end
943986
s = foldr(hash, get_observed(sys), init=s)
987+
s = foldr(hash, get_continuous_events(sys), init=s)
944988
s = hash(independent_variables(sys), s)
945989
return s
946990
end
@@ -968,13 +1012,14 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameo
9681012
sts = union(get_states(basesys), get_states(sys))
9691013
ps = union(get_ps(basesys), get_ps(sys))
9701014
obs = union(get_observed(basesys), get_observed(sys))
1015+
evs = union(get_continuous_events(basesys), get_continuous_events(sys))
9711016
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
9721017
syss = union(get_systems(basesys), get_systems(sys))
9731018

9741019
if length(ivs) == 0
975-
T(eqs, sts, ps, observed = obs, defaults = defs, name=name, systems = syss)
1020+
T(eqs, sts, ps, observed = obs, defaults = defs, name=name, systems = syss, continuous_events=evs)
9761021
elseif length(ivs) == 1
977-
T(eqs, ivs[1], sts, ps, observed = obs, defaults = defs, name = name, systems = syss)
1022+
T(eqs, ivs[1], sts, ps, observed = obs, defaults = defs, name = name, systems = syss, continuous_events=evs)
9781023
end
9791024
end
9801025

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 116 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,105 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
153153
PeriodicCallback(cb_affect!, first(dt))
154154
end
155155

156+
function generate_rootfinding_callback(sys::ODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
157+
cbs = continuous_events(sys)
158+
isempty(cbs) && return nothing
159+
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
160+
end
161+
162+
function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
163+
eqs = map(cb->cb.eqs, cbs)
164+
num_eqs = length.(eqs)
165+
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
166+
# fuse equations to create VectorContinuousCallback
167+
eqs = reduce(vcat, eqs)
168+
# rewrite all equations as 0 ~ interesting stuff
169+
eqs = map(eqs) do eq
170+
isequal(eq.lhs, 0) && return eq
171+
0 ~ eq.lhs - eq.rhs
172+
end
173+
174+
rhss = map(x->x.rhs, eqs)
175+
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
176+
177+
u = map(x->time_varying_as_func(value(x), sys), dvs)
178+
p = map(x->time_varying_as_func(value(x), sys), ps)
179+
t = get_iv(sys)
180+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression=Val{false}, kwargs...)
181+
182+
affect_functions = map(cbs) do cb # Keep affect function separate
183+
eq_aff = affect_equations(cb)
184+
affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...)
185+
end
186+
187+
if length(eqs) == 1
188+
cond = function(u, t, integ)
189+
if DiffEqBase.isinplace(integ.sol.prob)
190+
tmp, = DiffEqBase.get_tmp_cache(integ)
191+
rf_ip(tmp, u, integ.p, t)
192+
tmp[1]
193+
else
194+
rf_oop(u, integ.p, t)
195+
end
196+
end
197+
ContinuousCallback(cond, affect_functions[])
198+
else
199+
cond = function(out, u, t, integ)
200+
rf_ip(out, u, integ.p, t)
201+
end
202+
203+
# since there may be different number of conditions and affects,
204+
# we build a map that translates the condition eq. number to the affect number
205+
eq_ind2affect = reduce(vcat, [fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
206+
@assert length(eq_ind2affect) == length(eqs)
207+
@assert maximum(eq_ind2affect) == length(affect_functions)
208+
209+
affect = let affect_functions=affect_functions, eq_ind2affect=eq_ind2affect
210+
function(integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
211+
affect_functions[eq_ind2affect[eq_ind]](integ)
212+
end
213+
end
214+
VectorContinuousCallback(cond, affect, length(eqs))
215+
end
216+
end
217+
218+
compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...) = compile_affect(affect_equations(cb), args...; kwargs...)
219+
220+
"""
221+
compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
222+
compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
223+
224+
Returns a function that takes an integrator as argument and modifies the state with the affect.
225+
"""
226+
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
227+
if isempty(eqs)
228+
return (args...) -> () # We don't do anything in the callback, we're just after the event
229+
else
230+
rhss = map(x->x.rhs, eqs)
231+
lhss = map(x->x.lhs, eqs)
232+
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
233+
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
234+
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
235+
vars = states(sys)
236+
237+
u = map(x->time_varying_as_func(value(x), sys), vars)
238+
p = map(x->time_varying_as_func(value(x), sys), ps)
239+
t = get_iv(sys)
240+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression=Val{false}, kwargs...)
241+
242+
stateind(sym) = findfirst(isequal(sym),vars)
243+
244+
update_inds = stateind.(update_vars)
245+
let update_inds=update_inds
246+
function(integ)
247+
lhs = @views integ.u[update_inds]
248+
rf_ip(lhs, integ.u, integ.p, integ.t)
249+
end
250+
end
251+
end
252+
end
253+
254+
156255
function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
157256
# if something is not x(t) (the current state)
158257
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
@@ -552,15 +651,28 @@ Generates an ODEProblem from an ODESystem and allows for automatically
552651
symbolically calculating numerical enhancements.
553652
"""
554653
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
555-
parammap=DiffEqBase.NullParameters();kwargs...) where iip
654+
parammap=DiffEqBase.NullParameters(); callback=nothing, kwargs...) where iip
556655
has_difference = any(isdifferenceeq, equations(sys))
557656
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; has_difference=has_difference, kwargs...)
558-
if has_difference
559-
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys;kwargs...),kwargs...)
657+
if has_continuous_events(sys)
658+
event_cb = generate_rootfinding_callback(sys; kwargs...)
659+
else
660+
event_cb = nothing
661+
end
662+
difference_cb = has_difference ? generate_difference_cb(sys; kwargs...) : nothing
663+
cb = merge_cb(event_cb, difference_cb)
664+
cb = merge_cb(cb, callback)
665+
666+
if cb === nothing
667+
ODEProblem{iip}(f, u0, tspan, p; kwargs...)
560668
else
561-
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
669+
ODEProblem{iip}(f, u0, tspan, p; callback=cb, kwargs...)
562670
end
563671
end
672+
merge_cb(::Nothing, ::Nothing) = nothing
673+
merge_cb(::Nothing, x) = merge_cb(x, nothing)
674+
merge_cb(x, ::Nothing) = x
675+
merge_cb(x, y) = CallbackSet(x, y)
564676

565677
"""
566678
```julia

0 commit comments

Comments
 (0)