|
76 | 76 |
|
77 | 77 | ################################# compilation functions ####################################
|
78 | 78 |
|
| 79 | +# handles ensuring that affect! functions work with integrator arguments |
| 80 | +function add_integrator_header() |
| 81 | + integrator = gensym(:MTKIntegrator) |
| 82 | + |
| 83 | + expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [], |
| 84 | + expr.body), |
| 85 | + expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :u, :p, :t])], [], |
| 86 | + expr.body) |
| 87 | +end |
| 88 | + |
79 | 89 | function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
|
80 | 90 | compile_affect(affect_equations(cb), args...; kwargs...)
|
81 | 91 | end
|
82 | 92 |
|
83 | 93 | """
|
84 |
| - compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...) |
| 94 | + compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression, outputidxs, kwargs...) |
85 | 95 | compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
|
86 | 96 |
|
87 |
| -Returns a function that takes an integrator as argument and modifies the state with the affect. |
| 97 | +Returns a function that takes an integrator as argument and modifies the state with the |
| 98 | +affect. The generated function has the signature `affect!(integrator)`. |
| 99 | +
|
| 100 | +Notes |
| 101 | +- `expression = Val{true}`, causes the generated function to be returned as an expression. |
| 102 | + If set to `Val{false}` a `RuntimeGeneratedFunction` will be returned. |
| 103 | +- `outputidxs`, a vector of indices of the output variables. |
| 104 | +- `kwargs` are passed through to `Symbolics.build_function`. |
88 | 105 | """
|
89 |
| -function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression = Val{false}, kwargs...) |
| 106 | +function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing, |
| 107 | + expression = Val{true}, |
| 108 | + kwargs...) |
90 | 109 | if isempty(eqs)
|
91 |
| - return (args...) -> () # We don't do anything in the callback, we're just after the event |
| 110 | + if expression == Val{true} |
| 111 | + return :((args...) -> ()) |
| 112 | + else |
| 113 | + return (args...) -> () # We don't do anything in the callback, we're just after the event |
| 114 | + end |
92 | 115 | else
|
93 | 116 | rhss = map(x -> x.rhs, eqs)
|
94 |
| - lhss = map(x -> x.lhs, eqs) |
95 |
| - update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning |
96 |
| - length(update_vars) == length(unique(update_vars)) == length(eqs) || |
97 |
| - error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.") |
98 |
| - vars = states(sys) |
99 | 117 |
|
100 |
| - u = map(x -> time_varying_as_func(value(x), sys), vars) |
| 118 | + if outputidxs === nothing |
| 119 | + lhss = map(x -> x.lhs, eqs) |
| 120 | + update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning |
| 121 | + length(update_vars) == length(unique(update_vars)) == length(eqs) || |
| 122 | + error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.") |
| 123 | + stateind(sym) = findfirst(isequal(sym), dvs) |
| 124 | + update_inds = stateind.(update_vars) |
| 125 | + else |
| 126 | + update_inds = outputidxs |
| 127 | + end |
| 128 | + |
| 129 | + u = map(x -> time_varying_as_func(value(x), sys), dvs) |
101 | 130 | p = map(x -> time_varying_as_func(value(x), sys), ps)
|
102 | 131 | t = get_iv(sys)
|
103 |
| - # stateind(sym) = findfirst(isequal(sym), vars) |
104 |
| - # update_inds = stateind.(update_vars) |
105 |
| - # rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, wrap_code=add_integrator_header(), outputidxs = update_inds, kwargs...) |
106 |
| - # rf_ip |
107 |
| - |
108 |
| - rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, kwargs...) |
109 |
| - |
110 |
| - stateind(sym) = findfirst(isequal(sym), vars) |
111 |
| - |
112 |
| - update_inds = stateind.(update_vars) |
113 |
| - let update_inds = update_inds |
114 |
| - function (integ) |
115 |
| - lhs = @views integ.u[update_inds] |
116 |
| - rf_ip(lhs, integ.u, integ.p, integ.t) |
117 |
| - end |
118 |
| - end |
| 132 | + rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, |
| 133 | + wrap_code = add_integrator_header(), |
| 134 | + outputidxs = update_inds, |
| 135 | + kwargs...) |
| 136 | + rf_ip |
119 | 137 | end
|
120 | 138 | end
|
121 | 139 |
|
@@ -150,7 +168,7 @@ function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys),
|
150 | 168 |
|
151 | 169 | affect_functions = map(cbs) do cb # Keep affect function separate
|
152 | 170 | eq_aff = affect_equations(cb)
|
153 |
| - affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...) |
| 171 | + affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...) |
154 | 172 | end
|
155 | 173 |
|
156 | 174 | if length(eqs) == 1
|
|
0 commit comments