|
| 1 | +#################################### system operations ##################################### |
| 2 | +get_continuous_events(sys::AbstractSystem) = Equation[] |
| 3 | +get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events) |
| 4 | +has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events) |
| 5 | + |
| 6 | +#################################### continuous events ##################################### |
| 7 | + |
| 8 | +const NULL_AFFECT = Equation[] |
| 9 | +struct SymbolicContinuousCallback |
| 10 | + eqs::Vector{Equation} |
| 11 | + affect::Vector{Equation} |
| 12 | + function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT) |
| 13 | + new(eqs, affect) |
| 14 | + end # Default affect to nothing |
| 15 | +end |
| 16 | + |
| 17 | +function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) |
| 18 | + isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) |
| 19 | +end |
| 20 | +Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs) |
| 21 | +function Base.hash(cb::SymbolicContinuousCallback, s::UInt) |
| 22 | + s = foldr(hash, cb.eqs, init = s) |
| 23 | + foldr(hash, cb.affect, init = s) |
| 24 | +end |
| 25 | + |
| 26 | +to_equation_vector(eq::Equation) = [eq] |
| 27 | +to_equation_vector(eqs::Vector{Equation}) = eqs |
| 28 | +function to_equation_vector(eqs::Vector{Any}) |
| 29 | + isempty(eqs) || error("This should never happen") |
| 30 | + Equation[] |
| 31 | +end |
| 32 | + |
| 33 | +function SymbolicContinuousCallback(args...) |
| 34 | + SymbolicContinuousCallback(to_equation_vector.(args)...) |
| 35 | +end # wrap eq in vector |
| 36 | +SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) |
| 37 | +SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough |
| 38 | + |
| 39 | +SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] |
| 40 | +SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs |
| 41 | +SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs) |
| 42 | +function SymbolicContinuousCallbacks(ve::Vector{Equation}) |
| 43 | + SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve)) |
| 44 | +end |
| 45 | +function SymbolicContinuousCallbacks(others) |
| 46 | + SymbolicContinuousCallbacks(SymbolicContinuousCallback(others)) |
| 47 | +end |
| 48 | +SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[]) |
| 49 | + |
| 50 | +equations(cb::SymbolicContinuousCallback) = cb.eqs |
| 51 | +function equations(cbs::Vector{<:SymbolicContinuousCallback}) |
| 52 | + reduce(vcat, [equations(cb) for cb in cbs]) |
| 53 | +end |
| 54 | +affect_equations(cb::SymbolicContinuousCallback) = cb.affect |
| 55 | +function affect_equations(cbs::Vector{SymbolicContinuousCallback}) |
| 56 | + reduce(vcat, [affect_equations(cb) for cb in cbs]) |
| 57 | +end |
| 58 | +namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb), |
| 59 | + (s,)), |
| 60 | + namespace_equation.(affect_equations(cb), |
| 61 | + (s,))) |
| 62 | + |
| 63 | +function continuous_events(sys::AbstractSystem) |
| 64 | + obs = get_continuous_events(sys) |
| 65 | + filter(!isempty, obs) |
| 66 | + systems = get_systems(sys) |
| 67 | + cbs = [obs; |
| 68 | + reduce(vcat, |
| 69 | + (map(o -> namespace_equation(o, s), continuous_events(s)) |
| 70 | + for s in systems), |
| 71 | + init = SymbolicContinuousCallback[])] |
| 72 | + filter(!isempty, cbs) |
| 73 | +end |
| 74 | + |
| 75 | +################################# compilation functions #################################### |
| 76 | + |
| 77 | +# handles ensuring that affect! functions work with integrator arguments |
| 78 | +function add_integrator_header() |
| 79 | + integrator = gensym(:MTKIntegrator) |
| 80 | + |
| 81 | + expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [], |
| 82 | + expr.body), |
| 83 | + expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :u, :p, :t])], [], |
| 84 | + expr.body) |
| 85 | +end |
| 86 | + |
| 87 | +function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...) |
| 88 | + compile_affect(affect_equations(cb), args...; kwargs...) |
| 89 | +end |
| 90 | + |
| 91 | +""" |
| 92 | + compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression, outputidxs, kwargs...) |
| 93 | + compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...) |
| 94 | +
|
| 95 | +Returns a function that takes an integrator as argument and modifies the state with the |
| 96 | +affect. The generated function has the signature `affect!(integrator)`. |
| 97 | +
|
| 98 | +Notes |
| 99 | +- `expression = Val{true}`, causes the generated function to be returned as an expression. |
| 100 | + If set to `Val{false}` a `RuntimeGeneratedFunction` will be returned. |
| 101 | +- `outputidxs`, a vector of indices of the output variables which should correspond to |
| 102 | + `states(sys)`. If provided, checks that the LHS of affect equations are variables are |
| 103 | + dropped, i.e. it is assumed these indices are correct and affect equations are |
| 104 | + well-formed. |
| 105 | +- `kwargs` are passed through to `Symbolics.build_function`. |
| 106 | +""" |
| 107 | +function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing, |
| 108 | + expression = Val{true}, checkvars = true, kwargs...) |
| 109 | + if isempty(eqs) |
| 110 | + if expression == Val{true} |
| 111 | + return :((args...) -> ()) |
| 112 | + else |
| 113 | + return (args...) -> () # We don't do anything in the callback, we're just after the event |
| 114 | + end |
| 115 | + else |
| 116 | + rhss = map(x -> x.rhs, eqs) |
| 117 | + |
| 118 | + if outputidxs === nothing |
| 119 | + lhss = map(x -> x.lhs, eqs) |
| 120 | + all(isvariable, lhss) || |
| 121 | + error("Non-variable symbolic expression found on the left hand side of an affect equation. Such equations must be of the form variable ~ symbolic expression for the new value of the variable.") |
| 122 | + update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning |
| 123 | + length(update_vars) == length(unique(update_vars)) == length(eqs) || |
| 124 | + error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.") |
| 125 | + stateind(sym) = findfirst(isequal(sym), dvs) |
| 126 | + update_inds = stateind.(update_vars) |
| 127 | + else |
| 128 | + update_inds = outputidxs |
| 129 | + end |
| 130 | + |
| 131 | + if checkvars |
| 132 | + u = map(x -> time_varying_as_func(value(x), sys), dvs) |
| 133 | + p = map(x -> time_varying_as_func(value(x), sys), ps) |
| 134 | + else |
| 135 | + u = dvs |
| 136 | + p = ps |
| 137 | + end |
| 138 | + t = get_iv(sys) |
| 139 | + rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, |
| 140 | + wrap_code = add_integrator_header(), |
| 141 | + outputidxs = update_inds, |
| 142 | + kwargs...) |
| 143 | + rf_ip |
| 144 | + end |
| 145 | +end |
| 146 | + |
| 147 | +function generate_rootfinding_callback(sys::AbstractODESystem, dvs = states(sys), |
| 148 | + ps = parameters(sys); kwargs...) |
| 149 | + cbs = continuous_events(sys) |
| 150 | + isempty(cbs) && return nothing |
| 151 | + generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...) |
| 152 | +end |
| 153 | + |
| 154 | +function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states(sys), |
| 155 | + ps = parameters(sys); kwargs...) |
| 156 | + eqs = map(cb -> cb.eqs, cbs) |
| 157 | + num_eqs = length.(eqs) |
| 158 | + (isempty(eqs) || sum(num_eqs) == 0) && return nothing |
| 159 | + # fuse equations to create VectorContinuousCallback |
| 160 | + eqs = reduce(vcat, eqs) |
| 161 | + # rewrite all equations as 0 ~ interesting stuff |
| 162 | + eqs = map(eqs) do eq |
| 163 | + isequal(eq.lhs, 0) && return eq |
| 164 | + 0 ~ eq.lhs - eq.rhs |
| 165 | + end |
| 166 | + |
| 167 | + rhss = map(x -> x.rhs, eqs) |
| 168 | + root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss)))) |
| 169 | + |
| 170 | + u = map(x -> time_varying_as_func(value(x), sys), dvs) |
| 171 | + p = map(x -> time_varying_as_func(value(x), sys), ps) |
| 172 | + t = get_iv(sys) |
| 173 | + rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false}, kwargs...) |
| 174 | + |
| 175 | + affect_functions = map(cbs) do cb # Keep affect function separate |
| 176 | + eq_aff = affect_equations(cb) |
| 177 | + affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...) |
| 178 | + end |
| 179 | + |
| 180 | + if length(eqs) == 1 |
| 181 | + cond = function (u, t, integ) |
| 182 | + if DiffEqBase.isinplace(integ.sol.prob) |
| 183 | + tmp, = DiffEqBase.get_tmp_cache(integ) |
| 184 | + rf_ip(tmp, u, integ.p, t) |
| 185 | + tmp[1] |
| 186 | + else |
| 187 | + rf_oop(u, integ.p, t) |
| 188 | + end |
| 189 | + end |
| 190 | + ContinuousCallback(cond, affect_functions[]) |
| 191 | + else |
| 192 | + cond = function (out, u, t, integ) |
| 193 | + rf_ip(out, u, integ.p, t) |
| 194 | + end |
| 195 | + |
| 196 | + # since there may be different number of conditions and affects, |
| 197 | + # we build a map that translates the condition eq. number to the affect number |
| 198 | + eq_ind2affect = reduce(vcat, |
| 199 | + [fill(i, num_eqs[i]) for i in eachindex(affect_functions)]) |
| 200 | + @assert length(eq_ind2affect) == length(eqs) |
| 201 | + @assert maximum(eq_ind2affect) == length(affect_functions) |
| 202 | + |
| 203 | + affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect |
| 204 | + function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations |
| 205 | + affect_functions[eq_ind2affect[eq_ind]](integ) |
| 206 | + end |
| 207 | + end |
| 208 | + VectorContinuousCallback(cond, affect, length(eqs)) |
| 209 | + end |
| 210 | +end |
0 commit comments