@@ -123,9 +123,10 @@ function namespace_equation(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCa
123
123
namespace_equation .(affect_equations (cb), Ref (s)))
124
124
end
125
125
126
+ SymbolicDiscreteCallbacks (cb:: Pair ) = SymbolicDiscreteCallback[SymbolicDiscreteCallback (cb)]
127
+ SymbolicDiscreteCallbacks (cbs:: Vector ) = SymbolicDiscreteCallback .(cbs)
126
128
SymbolicDiscreteCallbacks (cb:: SymbolicDiscreteCallback ) = [cb]
127
129
SymbolicDiscreteCallbacks (cbs:: Vector{<:SymbolicDiscreteCallback} ) = cbs
128
- SymbolicDiscreteCallbacks (cbs:: Vector ) = SymbolicDiscreteCallback .(cbs)
129
130
SymbolicDiscreteCallbacks (:: Nothing ) = SymbolicDiscreteCallback[]
130
131
131
132
function discrete_events (sys:: AbstractSystem )
@@ -135,19 +136,18 @@ function discrete_events(sys::AbstractSystem)
135
136
reduce (vcat,
136
137
(map (o -> namespace_equation (o, s), discrete_events (s)) for s in systems),
137
138
init = SymbolicDiscreteCallback[])]
138
- filter ( ! isempty, cbs)
139
+ cbs
139
140
end
140
141
141
-
142
142
# ################################ compilation functions ####################################
143
143
144
144
# handles ensuring that affect! functions work with integrator arguments
145
- function add_integrator_header ()
145
+ function add_integrator_header (out = :u )
146
146
integrator = gensym (:MTKIntegrator )
147
147
148
148
expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [:u , :p , :t ])], [],
149
149
expr. body),
150
- expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [:u , :u , :p , :t ])], [],
150
+ expr -> Func ([DestructuredArgs (expr. args, integrator, inds = [out , :u , :p , :t ])], [],
151
151
expr. body)
152
152
end
153
153
@@ -185,7 +185,7 @@ affect. The generated function has the signature `affect!(integrator)`.
185
185
Notes
186
186
- `expression = Val{true}`, causes the generated function to be returned as an expression.
187
187
If set to `Val{false}` a `RuntimeGeneratedFunction` will be returned.
188
- - `outputidxs`, a vector of indices of the output variables .
188
+ - `outputidxs`, a vector of indices of the states that correspond to outputs .
189
189
- `kwargs` are passed through to `Symbolics.build_function`.
190
190
"""
191
191
function compile_affect (eqs:: Vector{Equation} , sys, dvs, ps; outputidxs = nothing ,
@@ -200,13 +200,24 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
200
200
else
201
201
rhss = map (x -> x. rhs, eqs)
202
202
203
+ outvar = :u
203
204
if outputidxs === nothing
204
205
lhss = map (x -> x. lhs, eqs)
205
206
update_vars = collect (Iterators. flatten (map (ModelingToolkit. vars, lhss))) # these are the ones we're chaning
206
207
length (update_vars) == length (unique (update_vars)) == length (eqs) ||
207
208
error (" affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair." )
208
- stateind (sym) = findfirst (isequal (sym), dvs)
209
- update_inds = stateind .(update_vars)
209
+ alleq = all (isequal (isparameter (first (update_vars))),
210
+ Iterators. map (isparameter, update_vars))
211
+ if ! isparameter (first (lhss)) && alleq
212
+ stateind (sym) = findfirst (isequal (sym), dvs)
213
+ update_inds = stateind .(update_vars)
214
+ elseif isparameter (first (lhss)) && alleq
215
+ psind (sym) = findfirst (isequal (sym), ps)
216
+ update_inds = psind .(update_vars)
217
+ outvar = :p
218
+ else
219
+ error (" Error, building an affect function for a callback that wants to modify both parameters and states. This is not currently allowed in one individual callback." )
220
+ end
210
221
else
211
222
update_inds = outputidxs
212
223
end
@@ -215,7 +226,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
215
226
p = map (x -> time_varying_as_func (value (x), sys), ps)
216
227
t = get_iv (sys)
217
228
rf_oop, rf_ip = build_function (rhss, u, p, t; expression = expression,
218
- wrap_code = add_integrator_header (),
229
+ wrap_code = add_integrator_header (outvar ),
219
230
outputidxs = update_inds,
220
231
kwargs... )
221
232
rf_ip
0 commit comments