Skip to content

Commit 52d76af

Browse files
committed
update JumpSystems to use callback affects
1 parent 41a57ab commit 52d76af

File tree

2 files changed

+47
-42
lines changed

2 files changed

+47
-42
lines changed

src/systems/callbacks.jl

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -76,46 +76,64 @@ end
7676

7777
################################# compilation functions ####################################
7878

79+
# handles ensuring that affect! functions work with integrator arguments
80+
function add_integrator_header()
81+
integrator = gensym(:MTKIntegrator)
82+
83+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
84+
expr.body),
85+
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :u, :p, :t])], [],
86+
expr.body)
87+
end
88+
7989
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
8090
compile_affect(affect_equations(cb), args...; kwargs...)
8191
end
8292

8393
"""
84-
compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
94+
compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression, outputidxs, kwargs...)
8595
compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
8696
87-
Returns a function that takes an integrator as argument and modifies the state with the affect.
97+
Returns a function that takes an integrator as argument and modifies the state with the
98+
affect. The generated function has the signature `affect!(integrator)`.
99+
100+
Notes
101+
- `expression = Val{true}`, causes the generated function to be returned as an expression.
102+
If set to `Val{false}` a `RuntimeGeneratedFunction` will be returned.
103+
- `outputidxs`, a vector of indices of the output variables.
104+
- `kwargs` are passed through to `Symbolics.build_function`.
88105
"""
89-
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; expression = Val{false}, kwargs...)
106+
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
107+
expression = Val{true},
108+
kwargs...)
90109
if isempty(eqs)
91-
return (args...) -> () # We don't do anything in the callback, we're just after the event
110+
if expression == Val{true}
111+
return :((args...) -> ())
112+
else
113+
return (args...) -> () # We don't do anything in the callback, we're just after the event
114+
end
92115
else
93116
rhss = map(x -> x.rhs, eqs)
94-
lhss = map(x -> x.lhs, eqs)
95-
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
96-
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
97-
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
98-
vars = states(sys)
99117

100-
u = map(x -> time_varying_as_func(value(x), sys), vars)
118+
if outputidxs === nothing
119+
lhss = map(x -> x.lhs, eqs)
120+
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
121+
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
122+
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
123+
stateind(sym) = findfirst(isequal(sym), dvs)
124+
update_inds = stateind.(update_vars)
125+
else
126+
update_inds = outputidxs
127+
end
128+
129+
u = map(x -> time_varying_as_func(value(x), sys), dvs)
101130
p = map(x -> time_varying_as_func(value(x), sys), ps)
102131
t = get_iv(sys)
103-
# stateind(sym) = findfirst(isequal(sym), vars)
104-
# update_inds = stateind.(update_vars)
105-
# rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, wrap_code=add_integrator_header(), outputidxs = update_inds, kwargs...)
106-
# rf_ip
107-
108-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression, kwargs...)
109-
110-
stateind(sym) = findfirst(isequal(sym), vars)
111-
112-
update_inds = stateind.(update_vars)
113-
let update_inds = update_inds
114-
function (integ)
115-
lhs = @views integ.u[update_inds]
116-
rf_ip(lhs, integ.u, integ.p, integ.t)
117-
end
118-
end
132+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = expression,
133+
wrap_code = add_integrator_header(),
134+
outputidxs = update_inds,
135+
kwargs...)
136+
rf_ip
119137
end
120138
end
121139

@@ -150,7 +168,7 @@ function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys),
150168

151169
affect_functions = map(cbs) do cb # Keep affect function separate
152170
eq_aff = affect_equations(cb)
153-
affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...)
171+
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
154172
end
155173

156174
if length(eqs) == 1

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,10 @@ function generate_rate_function(js::JumpSystem, rate)
118118
conv = states_to_sym(states(js)),
119119
expression = Val{true})
120120
end
121-
function add_integrator_header()
122-
integrator = gensym(:MTKIntegrator)
123-
124-
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :p, :t])], [],
125-
expr.body),
126-
expr -> Func([DestructuredArgs(expr.args, integrator, inds = [:u, :u, :p, :t])], [],
127-
expr.body)
128-
end
129121

130122
function generate_affect_function(js::JumpSystem, affect, outputidxs)
131-
compile_affect(affect, js, states(js), parameters(js); expression = Val{true})
132-
# bf = build_function(map(x -> x isa Equation ? x.rhs : x, affect), states(js),
133-
# parameters(js),
134-
# get_iv(js),
135-
# expression = Val{true},
136-
# wrap_code = add_integrator_header(),
137-
# outputidxs = outputidxs)[2]
123+
compile_affect(affect, js, states(js), parameters(js); outputidxs = outputidxs,
124+
expression = Val{true})
138125
end
139126

140127
function assemble_vrj(js, vrj, statetoid)

0 commit comments

Comments
 (0)