Skip to content

Commit af11f9d

Browse files
committed
try to allow changing ps
1 parent 7e7eacf commit af11f9d

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

src/systems/callbacks.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,10 @@ function namespace_equation(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCa
123123
namespace_equation.(affect_equations(cb), Ref(s)))
124124
end
125125

126+
SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)]
127+
SymbolicDiscreteCallbacks(cbs::Vector) = SymbolicDiscreteCallback.(cbs)
126128
SymbolicDiscreteCallbacks(cb::SymbolicDiscreteCallback) = [cb]
127129
SymbolicDiscreteCallbacks(cbs::Vector{<:SymbolicDiscreteCallback}) = cbs
128-
SymbolicDiscreteCallbacks(cbs::Vector) = SymbolicDiscreteCallback.(cbs)
129130
SymbolicDiscreteCallbacks(::Nothing) = SymbolicDiscreteCallback[]
130131

131132
function discrete_events(sys::AbstractSystem)
@@ -135,19 +136,18 @@ function discrete_events(sys::AbstractSystem)
135136
reduce(vcat,
136137
(map(o -> namespace_equation(o, s), discrete_events(s)) for s in systems),
137138
init = SymbolicDiscreteCallback[])]
138-
filter(!isempty, cbs)
139+
cbs
139140
end
140141

141-
142142
################################# compilation functions ####################################
143143

144144
# handles ensuring that affect! functions work with integrator arguments
145-
function add_integrator_header()
145+
function add_integrator_header(out=:u)
146146
integrator = gensym(:MTKIntegrator)
147147

148148
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
149149
expr.body),
150-
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :u, :p, :t])], [],
150+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [out, :u, :p, :t])], [],
151151
expr.body)
152152
end
153153

@@ -185,7 +185,7 @@ affect. The generated function has the signature `affect!(integrator)`.
185185
Notes
186186
- `expression = Val{true}`, causes the generated function to be returned as an expression.
187187
If set to `Val{false}` a `RuntimeGeneratedFunction` will be returned.
188-
- `outputidxs`, a vector of indices of the output variables.
188+
- `outputidxs`, a vector of indices of the states that correspond to outputs.
189189
- `kwargs` are passed through to `Symbolics.build_function`.
190190
"""
191191
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
@@ -200,13 +200,24 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
200200
else
201201
rhss = map(x -> x.rhs, eqs)
202202

203+
outvar = :u
203204
if outputidxs === nothing
204205
lhss = map(x -> x.lhs, eqs)
205206
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
206207
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
207208
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
208-
stateind(sym) = findfirst(isequal(sym), dvs)
209-
update_inds = stateind.(update_vars)
209+
alleq = all(isequal(isparameter(first(update_vars))),
210+
Iterators.map(isparameter, update_vars))
211+
if !isparameter(first(lhss)) && alleq
212+
stateind(sym) = findfirst(isequal(sym), dvs)
213+
update_inds = stateind.(update_vars)
214+
elseif isparameter(first(lhss)) && alleq
215+
psind(sym) = findfirst(isequal(sym), ps)
216+
update_inds = psind.(update_vars)
217+
outvar = :p
218+
else
219+
error("Error, building an affect function for a callback that wants to modify both parameters and states. This is not currently allowed in one individual callback.")
220+
end
210221
else
211222
update_inds = outputidxs
212223
end
@@ -215,7 +226,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
215226
p = map(x -> time_varying_as_func(value(x), sys), ps)
216227
t = get_iv(sys)
217228
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression,
218-
wrap_code = add_integrator_header(),
229+
wrap_code = add_integrator_header(outvar),
219230
outputidxs = update_inds,
220231
kwargs...)
221232
rf_ip

test/root_equations.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,25 @@ end
334334
model = Model(sin(30t))
335335
sys = structural_simplify(model)
336336
@test isempty(ModelingToolkit.continuous_events(sys))
337+
338+
339+
# let
340+
# @parameters k t1 t2
341+
# @variables t A(t)
342+
343+
# cond1 = (t == t1)
344+
# affect1 = [A ~ A + 1]
345+
# cb1 = cond1 => affect1
346+
# cond2 = (t == t2)
347+
# affect2 = [k ~ 1.0]
348+
# cb2 = cond2 => affect2
349+
350+
# ∂ₜ = Differential(t)
351+
# eqs = [∂ₜ(A) ~ -k*A]
352+
# @named osys = ODESystem(eqs, t, discrete_events=[cb1,cb2])
353+
# u0 = [A => 1.0]
354+
# p = [k => 0.0, t1 => 1.0, t2 => 2.0]
355+
# tspan = (0.0, 4.0)
356+
# oprob = ODEProblem(osys, u0, tspan, p)
357+
# sol = solve(oprob, Tsit5())
358+
# end

0 commit comments

Comments
 (0)