|
| 1 | +#################################### system operations ##################################### |
| 2 | +get_continuous_events(sys::AbstractSystem) = Equation[] |
| 3 | +get_continuous_events(sys::AbstractODESystem) = getfield(sys, :continuous_events) |
| 4 | +has_continuous_events(sys::AbstractSystem) = isdefined(sys, :continuous_events) |
| 5 | + |
| 6 | + |
| 7 | +#################################### continuous events ##################################### |
| 8 | + |
| 9 | +const NULL_AFFECT = Equation[] |
| 10 | +struct SymbolicContinuousCallback |
| 11 | + eqs::Vector{Equation} |
| 12 | + affect::Vector{Equation} |
| 13 | + function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT) |
| 14 | + new(eqs, affect) |
| 15 | + end # Default affect to nothing |
| 16 | +end |
| 17 | + |
| 18 | +function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) |
| 19 | + isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) |
| 20 | +end |
| 21 | +Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs) |
| 22 | +function Base.hash(cb::SymbolicContinuousCallback, s::UInt) |
| 23 | + s = foldr(hash, cb.eqs, init = s) |
| 24 | + foldr(hash, cb.affect, init = s) |
| 25 | +end |
| 26 | + |
| 27 | +to_equation_vector(eq::Equation) = [eq] |
| 28 | +to_equation_vector(eqs::Vector{Equation}) = eqs |
| 29 | +function to_equation_vector(eqs::Vector{Any}) |
| 30 | + isempty(eqs) || error("This should never happen") |
| 31 | + Equation[] |
| 32 | +end |
| 33 | + |
| 34 | +function SymbolicContinuousCallback(args...) |
| 35 | + SymbolicContinuousCallback(to_equation_vector.(args)...) |
| 36 | +end # wrap eq in vector |
| 37 | +SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) |
| 38 | +SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough |
| 39 | + |
| 40 | +SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] |
| 41 | +SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs |
| 42 | +SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs) |
| 43 | +function SymbolicContinuousCallbacks(ve::Vector{Equation}) |
| 44 | + SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve)) |
| 45 | +end |
| 46 | +function SymbolicContinuousCallbacks(others) |
| 47 | + SymbolicContinuousCallbacks(SymbolicContinuousCallback(others)) |
| 48 | +end |
| 49 | +SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[]) |
| 50 | + |
| 51 | +equations(cb::SymbolicContinuousCallback) = cb.eqs |
| 52 | +function equations(cbs::Vector{<:SymbolicContinuousCallback}) |
| 53 | + reduce(vcat, [equations(cb) for cb in cbs]) |
| 54 | +end |
| 55 | +affect_equations(cb::SymbolicContinuousCallback) = cb.affect |
| 56 | +function affect_equations(cbs::Vector{SymbolicContinuousCallback}) |
| 57 | + reduce(vcat, [affect_equations(cb) for cb in cbs]) |
| 58 | +end |
| 59 | +namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb), |
| 60 | + (s,)), |
| 61 | + namespace_equation.(affect_equations(cb), |
| 62 | + (s,))) |
| 63 | + |
| 64 | +function continuous_events(sys::AbstractSystem) |
| 65 | + obs = get_continuous_events(sys) |
| 66 | + filter(!isempty, obs) |
| 67 | + systems = get_systems(sys) |
| 68 | + cbs = [obs; |
| 69 | + reduce(vcat, |
| 70 | + (map(o -> namespace_equation(o, s), continuous_events(s)) |
| 71 | + for s in systems), |
| 72 | + init = SymbolicContinuousCallback[])] |
| 73 | + filter(!isempty, cbs) |
| 74 | +end |
| 75 | + |
| 76 | + |
| 77 | +################################# compilation functions #################################### |
| 78 | + |
| 79 | +function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...) |
| 80 | + compile_affect(affect_equations(cb), args...; kwargs...) |
| 81 | +end |
| 82 | + |
| 83 | +""" |
| 84 | + compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...) |
| 85 | + compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...) |
| 86 | +
|
| 87 | +Returns a function that takes an integrator as argument and modifies the state with the affect. |
| 88 | +""" |
| 89 | +function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression = Val{false}, kwargs...) |
| 90 | + if isempty(eqs) |
| 91 | + return (args...) -> () # We don't do anything in the callback, we're just after the event |
| 92 | + else |
| 93 | + rhss = map(x -> x.rhs, eqs) |
| 94 | + lhss = map(x -> x.lhs, eqs) |
| 95 | + update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning |
| 96 | + length(update_vars) == length(unique(update_vars)) == length(eqs) || |
| 97 | + error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.") |
| 98 | + vars = states(sys) |
| 99 | + |
| 100 | + u = map(x -> time_varying_as_func(value(x), sys), vars) |
| 101 | + p = map(x -> time_varying_as_func(value(x), sys), ps) |
| 102 | + t = get_iv(sys) |
| 103 | + # stateind(sym) = findfirst(isequal(sym), vars) |
| 104 | + # update_inds = stateind.(update_vars) |
| 105 | + # rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, wrap_code=add_integrator_header(), outputidxs = update_inds, kwargs...) |
| 106 | + # rf_ip |
| 107 | + |
| 108 | + rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, kwargs...) |
| 109 | + |
| 110 | + stateind(sym) = findfirst(isequal(sym), vars) |
| 111 | + |
| 112 | + update_inds = stateind.(update_vars) |
| 113 | + let update_inds = update_inds |
| 114 | + function (integ) |
| 115 | + lhs = @views integ.u[update_inds] |
| 116 | + rf_ip(lhs, integ.u, integ.p, integ.t) |
| 117 | + end |
| 118 | + end |
| 119 | + end |
| 120 | +end |
| 121 | + |
| 122 | + |
| 123 | +function generate_rootfinding_callback(sys::ODESystem, dvs = states(sys), |
| 124 | + ps = parameters(sys); kwargs...) |
| 125 | + cbs = continuous_events(sys) |
| 126 | + isempty(cbs) && return nothing |
| 127 | + generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...) |
| 128 | +end |
| 129 | + |
| 130 | +function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys), |
| 131 | + ps = parameters(sys); kwargs...) |
| 132 | + eqs = map(cb -> cb.eqs, cbs) |
| 133 | + num_eqs = length.(eqs) |
| 134 | + (isempty(eqs) || sum(num_eqs) == 0) && return nothing |
| 135 | + # fuse equations to create VectorContinuousCallback |
| 136 | + eqs = reduce(vcat, eqs) |
| 137 | + # rewrite all equations as 0 ~ interesting stuff |
| 138 | + eqs = map(eqs) do eq |
| 139 | + isequal(eq.lhs, 0) && return eq |
| 140 | + 0 ~ eq.lhs - eq.rhs |
| 141 | + end |
| 142 | + |
| 143 | + rhss = map(x -> x.rhs, eqs) |
| 144 | + root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss)))) |
| 145 | + |
| 146 | + u = map(x -> time_varying_as_func(value(x), sys), dvs) |
| 147 | + p = map(x -> time_varying_as_func(value(x), sys), ps) |
| 148 | + t = get_iv(sys) |
| 149 | + rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false}, kwargs...) |
| 150 | + |
| 151 | + affect_functions = map(cbs) do cb # Keep affect function separate |
| 152 | + eq_aff = affect_equations(cb) |
| 153 | + affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...) |
| 154 | + end |
| 155 | + |
| 156 | + if length(eqs) == 1 |
| 157 | + cond = function (u, t, integ) |
| 158 | + if DiffEqBase.isinplace(integ.sol.prob) |
| 159 | + tmp, = DiffEqBase.get_tmp_cache(integ) |
| 160 | + rf_ip(tmp, u, integ.p, t) |
| 161 | + tmp[1] |
| 162 | + else |
| 163 | + rf_oop(u, integ.p, t) |
| 164 | + end |
| 165 | + end |
| 166 | + ContinuousCallback(cond, affect_functions[]) |
| 167 | + else |
| 168 | + cond = function (out, u, t, integ) |
| 169 | + rf_ip(out, u, integ.p, t) |
| 170 | + end |
| 171 | + |
| 172 | + # since there may be different number of conditions and affects, |
| 173 | + # we build a map that translates the condition eq. number to the affect number |
| 174 | + eq_ind2affect = reduce(vcat, |
| 175 | + [fill(i, num_eqs[i]) for i in eachindex(affect_functions)]) |
| 176 | + @assert length(eq_ind2affect) == length(eqs) |
| 177 | + @assert maximum(eq_ind2affect) == length(affect_functions) |
| 178 | + |
| 179 | + affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect |
| 180 | + function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations |
| 181 | + affect_functions[eq_ind2affect[eq_ind]](integ) |
| 182 | + end |
| 183 | + end |
| 184 | + VectorContinuousCallback(cond, affect, length(eqs)) |
| 185 | + end |
| 186 | +end |
0 commit comments