Skip to content

Commit 7a94fdf

Browse files
feat: allow saving discrete variables in symbolic callbacks
1 parent 41e801e commit 7a94fdf

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

src/systems/callbacks.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -394,14 +394,15 @@ function callback_save_header(sys::AbstractSystem, cb)
394394
save_idxs = get(ic.callback_to_clocks, cb, Int[])
395395
isempty(save_idxs) && return (identity, identity)
396396

397-
wrapper = function(expr)
398-
return Func(expr.args, [], LiteralExpr(quote
399-
$(expr.body)
400-
save_idxs = $(save_idxs)
401-
for idx in save_idxs
402-
$(SciMLBase.save_discretes!)($(expr.args[1]), idx)
403-
end
404-
end))
397+
wrapper = function (expr)
398+
return Func(expr.args, [],
399+
LiteralExpr(quote
400+
$(expr.body)
401+
save_idxs = $(save_idxs)
402+
for idx in save_idxs
403+
$(SciMLBase.save_discretes!)($(expr.args[1]), idx)
404+
end
405+
end))
405406
end
406407

407408
return wrapper, wrapper
@@ -705,7 +706,9 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs.
705706
else
706707
save_idxs = Int[]
707708
end
708-
let u = u, p = p, user_affect = func(affect), ctx = context(affect), save_idxs = save_idxs
709+
let u = u, p = p, user_affect = func(affect), ctx = context(affect),
710+
save_idxs = save_idxs
711+
709712
function (integ)
710713
user_affect(integ, u, p, ctx)
711714
for idx in save_idxs

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ function generate_affect_function(js::JumpSystem, affect, outputidxs)
221221
csubs = Dict(c => getdefault(c) for c in consts)
222222
affect = substitute(affect, csubs)
223223
end
224-
compile_affect(affect, nothing, js, unknowns(js), parameters(js); outputidxs = outputidxs,
224+
compile_affect(
225+
affect, nothing, js, unknowns(js), parameters(js); outputidxs = outputidxs,
225226
expression = Val{true}, checkvars = false)
226227
end
227228

test/symbolic_events.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,12 +876,13 @@ end
876876
cb2 = [x ~ 0.5] => (save_affect!, [], [b], [b], nothing)
877877
cb3 = 1.0 => [c ~ t]
878878

879-
@mtkbuild sys = ODESystem(D(x) ~ cos(x), t, [x], [a, b, c]; continuous_events = [cb1, cb2], discrete_events = [cb3])
879+
@mtkbuild sys = ODESystem(D(x) ~ cos(t), t, [x], [a, b, c];
880+
continuous_events = [cb1, cb2], discrete_events = [cb3])
880881
prob = ODEProblem(sys, [x => 1.0], (0.0, 2pi), [a => 1.0, b => 2.0, c => 0.0])
881882
@test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0]
882883
sol = solve(prob, Tsit5())
883884

884885
@test sol[a] == [-1.0]
885886
@test sol[b] == [5.0, 5.0]
886-
@test sol[c] == [1.0, 2.0, 3.0]
887+
@test sol[c] == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
887888
end

0 commit comments

Comments
 (0)