Skip to content

Commit 25436f8

Browse files
committed
Fixes & tests for callbacks.
1 parent 04c7737 commit 25436f8

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

src/systems/callbacks.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,13 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
269269
p = map(x -> time_varying_as_func(value(x), sys), ps)
270270
t = get_iv(sys)
271271
condit = condition(cb)
272-
build_function(condit, u, t, p; expression, wrap_code = condition_header(), kwargs...)
272+
cs = collect_constants(condit)
273+
if !isempty(cs)
274+
cmap = map(x -> x => getdefault(x), cs)
275+
condit = substitute(condit, cmap)
276+
end
277+
build_function(condit, u, t, p; expression, wrap_code = condition_header(),
278+
kwargs...)
273279
end
274280

275281
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
@@ -337,9 +343,11 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
337343
t = get_iv(sys)
338344
integ = gensym(:MTKIntegrator)
339345
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
346+
pre = get_preprocess_constants(rhss)
340347
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = getexpr,
341348
wrap_code = add_integrator_header(integ, outvar),
342349
outputidxs = update_inds,
350+
postprocess_fbody = pre,
343351
kwargs...)
344352
# applied user-provided function to the generated expression
345353
if postprocess_affect_expr! !== nothing
@@ -376,7 +384,9 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states
376384
u = map(x -> time_varying_as_func(value(x), sys), dvs)
377385
p = map(x -> time_varying_as_func(value(x), sys), ps)
378386
t = get_iv(sys)
379-
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false}, kwargs...)
387+
pre = get_preprocess_constants(rhss)
388+
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false},
389+
postprocess_fbody = pre, kwargs...)
380390

381391
affect_functions = map(cbs) do cb # Keep affect function separate
382392
eq_aff = affects(cb)

test/funcaffect.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ModelingToolkit, Test, OrdinaryDiffEq
22

33
@parameters t
4+
@constants h = 1 zr = 0
45
@variables u(t)
56
D = Differential(t)
67

@@ -16,19 +17,19 @@ i4 = findfirst(==(4.0), sol[:t])
1617
@test sol.u[i4 + 1][1] > 10.0
1718

1819
# callback
19-
cb = ModelingToolkit.SymbolicDiscreteCallback(t == 0,
20+
cb = ModelingToolkit.SymbolicDiscreteCallback(t == zr,
2021
(f = affect1!, sts = [], pars = [],
2122
ctx = [1]))
22-
cb1 = ModelingToolkit.SymbolicDiscreteCallback(t == 0, (affect1!, [], [], [1]))
23+
cb1 = ModelingToolkit.SymbolicDiscreteCallback(t == zr, (affect1!, [], [], [1]))
2324
@test ModelingToolkit.affects(cb) isa ModelingToolkit.FunctionalAffect
2425
@test cb == cb1
2526
@test ModelingToolkit.SymbolicDiscreteCallback(cb) === cb # passthrough
2627
@test hash(cb) == hash(cb1)
2728

28-
cb = ModelingToolkit.SymbolicContinuousCallback([t ~ 0],
29+
cb = ModelingToolkit.SymbolicContinuousCallback([t ~ zr],
2930
(f = affect1!, sts = [], pars = [],
3031
ctx = [1]))
31-
cb1 = ModelingToolkit.SymbolicContinuousCallback([t ~ 0], (affect1!, [], [], [1]))
32+
cb1 = ModelingToolkit.SymbolicContinuousCallback([t ~ zr], (affect1!, [], [], [1]))
3233
@test cb == cb1
3334
@test ModelingToolkit.SymbolicContinuousCallback(cb) === cb # passthrough
3435
@test hash(cb) == hash(cb1)
@@ -48,7 +49,7 @@ de = de[1]
4849
@test ModelingToolkit.has_functional_affect(de)
4950

5051
sys2 = ODESystem(eqs, t, [u], [], name = :sys,
51-
discrete_events = [[4.0] => [u ~ -u]])
52+
discrete_events = [[4.0] => [u ~ -u * h]])
5253
@test !ModelingToolkit.has_functional_affect(ModelingToolkit.get_discrete_events(sys2)[1])
5354

5455
# context
@@ -121,7 +122,7 @@ i8 = findfirst(==(8.0), sol[:t])
121122
ctx = [0]
122123
function affect4!(integ, u, p, ctx)
123124
ctx[1] += 1
124-
@test u.resistor₊v == 1
125+
@test u.resistor₊v == h
125126
end
126127
s1 = compose(ODESystem(Equation[], t, [], [], name = :s1,
127128
discrete_events = 1.0 => (affect4!, [resistor.v], [], ctx)),
@@ -268,7 +269,7 @@ function bb_affect!(integ, u, p, ctx)
268269
end
269270

270271
@named bb_model = ODESystem(bb_eqs, t, sts, par,
271-
continuous_events = [[y ~ 0] => (bb_affect!, [v], [], nothing)])
272+
continuous_events = [[y ~ zr] => (bb_affect!, [v], [], nothing)])
272273

273274
bb_sys = structural_simplify(bb_model)
274275
@test ModelingToolkit.affects(ModelingToolkit.continuous_events(bb_sys)) isa

0 commit comments

Comments
 (0)