Skip to content

Commit 6424ca9

Browse files
committed
Initialize and finalize for discrete callbacks
1 parent 9e63823 commit 6424ca9

File tree

2 files changed

+111
-25
lines changed

2 files changed

+111
-25
lines changed

src/systems/callbacks.jl

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -328,13 +328,16 @@ struct SymbolicDiscreteCallback
328328
# TODO: Iterative
329329
condition::Any
330330
affects::Any
331+
initialize::Any
332+
finalize::Any
331333
reinitializealg::SciMLBase.DAEInitializationAlgorithm
332334

333335
function SymbolicDiscreteCallback(
334-
condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit())
336+
condition, affects = NULL_AFFECT; reinitializealg = SciMLBase.CheckInit(),
337+
initialize=NULL_AFFECT, finalize=NULL_AFFECT)
335338
c = scalarize_condition(condition)
336339
a = scalarize_affects(affects)
337-
new(c, a, reinitializealg)
340+
new(c, a, scalarize_affects(initialize), scalarize_affects(finalize), reinitializealg)
338341
end # Default affect to nothing
339342
end
340343

@@ -373,11 +376,16 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback)
373376
end
374377

375378
function Base.:(==)(e1::SymbolicDiscreteCallback, e2::SymbolicDiscreteCallback)
376-
isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects)
379+
isequal(e1.condition, e2.condition) && isequal(e1.affects, e2.affects) &&
380+
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)
377381
end
378382
function Base.hash(cb::SymbolicDiscreteCallback, s::UInt)
379383
s = hash(cb.condition, s)
380-
cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s)
384+
s = cb.affects isa AbstractVector ? foldr(hash, cb.affects, init = s) : hash(cb.affects, s)
385+
s = cb.initialize isa AbstractVector ? foldr(hash, cb.initialize, init = s) : hash(cb.initialize, s)
386+
s = cb.finalize isa AbstractVector ? foldr(hash, cb.finalize, init = s) : hash(cb.finalize, s)
387+
s = hash(cb.reinitializealg, s)
388+
return s
381389
end
382390

383391
condition(cb::SymbolicDiscreteCallback) = cb.condition
@@ -397,10 +405,23 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
397405
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
398406
end
399407

408+
409+
initialize_affects(cb::SymbolicDiscreteCallback) = cb.initialize
410+
function initialize_affects(cbs::Vector{SymbolicDiscreteCallback})
411+
mapreduce(initialize_affects, vcat, cbs, init = Equation[])
412+
end
413+
414+
finalize_affects(cb::SymbolicDiscreteCallback) = cb.finalize
415+
function finalize_affects(cbs::Vector{SymbolicDiscreteCallback})
416+
mapreduce(finalize_affects, vcat, cbs, init = Equation[])
417+
end
418+
400419
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
401-
af = affects(cb)
402-
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
403-
SymbolicDiscreteCallback(namespace_condition(condition(cb), s), af)
420+
function namespace_affects(af)
421+
return af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
422+
end
423+
SymbolicDiscreteCallback(namespace_condition(condition(cb), s), namespace_affects(affects(cb)),
424+
reinitializealg=cb.reinitializealg, initialize=namespace_affects(initialize_affects(cb)), finalize=namespace_affects(finalize_affects(cb)))
404425
end
405426

406427
SymbolicDiscreteCallbacks(cb::Pair) = SymbolicDiscreteCallback[SymbolicDiscreteCallback(cb)]
@@ -773,10 +794,10 @@ function generate_vector_rootfinding_callback(
773794
let save_idxs = save_idxs
774795
if !isnothing(fn.initialize)
775796
(i) -> begin
797+
fn.initialize(i)
776798
for idx in save_idxs
777799
SciMLBase.save_discretes!(i, idx)
778800
end
779-
fn.initialize(i)
780801
end
781802
else
782803
(i) -> begin
@@ -809,20 +830,13 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
809830
eq_aff = affects(cb)
810831
eq_neg_aff = affect_negs(cb)
811832
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
812-
function compile_optional_affect(aff, default = nothing)
813-
if isnothing(aff) || aff == default
814-
return nothing
815-
else
816-
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
817-
end
818-
end
819833
if eq_neg_aff === eq_aff
820834
affect_neg = affect
821835
else
822-
affect_neg = compile_optional_affect(eq_neg_aff)
836+
affect_neg = _compile_optional_affect(NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs...)
823837
end
824-
initialize = compile_optional_affect(initialize_affects(cb), NULL_AFFECT)
825-
finalize = compile_optional_affect(finalize_affects(cb), NULL_AFFECT)
838+
initialize = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
839+
finalize = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
826840
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
827841
end
828842

@@ -914,31 +928,48 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
914928
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
915929
end
916930

931+
932+
function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...)
933+
if isnothing(aff) || aff == default
934+
return nothing
935+
else
936+
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
937+
end
938+
end
917939
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
918940
kwargs...)
919941
cond = condition(cb)
920942
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
921943
postprocess_affect_expr!, kwargs...)
944+
945+
user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
946+
user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
922947
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
923948
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
924-
initfn = let save_idxs = save_idxs
949+
initfn = let
950+
save_idxs = save_idxs
951+
initfun=user_initfun
925952
function (cb, u, t, integrator)
953+
if !isnothing(initfun)
954+
initfun(integrator)
955+
end
926956
for idx in save_idxs
927957
SciMLBase.save_discretes!(integrator, idx)
928958
end
929959
end
930960
end
931961
else
932-
initfn = SciMLBase.INITIALIZE_DEFAULT
962+
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i)
933963
end
964+
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i)
934965
if cond isa AbstractVector
935966
# Preset Time
936967
return PresetTimeCallback(
937-
cond, as; initialize = initfn, initializealg = reinitialization_alg(cb))
968+
cond, as; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb))
938969
else
939970
# Periodic
940971
return PeriodicCallback(
941-
as, cond; initialize = initfn, initializealg = reinitialization_alg(cb))
972+
as, cond; initialize = initfn, finalize=finfun, initializealg = reinitialization_alg(cb))
942973
end
943974
end
944975

@@ -951,20 +982,27 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
951982
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
952983
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
953984
postprocess_affect_expr!, kwargs...)
985+
986+
user_initfun = _compile_optional_affect(NULL_AFFECT, initialize_affects(cb), cb, sys, dvs, ps; kwargs...)
987+
user_finfun = _compile_optional_affect(NULL_AFFECT, finalize_affects(cb), cb, sys, dvs, ps; kwargs...)
954988
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
955989
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
956-
initfn = let save_idxs = save_idxs
990+
initfn = let save_idxs = save_idxs, initfun = user_initfun
957991
function (cb, u, t, integrator)
992+
if !isnothing(initfun)
993+
initfun(integrator)
994+
end
958995
for idx in save_idxs
959996
SciMLBase.save_discretes!(integrator, idx)
960997
end
961998
end
962999
end
9631000
else
964-
initfn = SciMLBase.INITIALIZE_DEFAULT
1001+
initfn = isnothing(user_initfun) ? SciMLBase.INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun(i)
9651002
end
1003+
finfun = isnothing(user_finfun) ? SciMLBase.FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun(i)
9661004
return DiscreteCallback(
967-
c, as; initialize = initfn, initializealg = reinitialization_alg(cb))
1005+
c, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg(cb))
9681006
end
9691007
end
9701008

test/symbolic_events.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,54 @@ end
10011001
@test seen == true
10021002
@test inited == true
10031003
@test finaled == true
1004+
1005+
#periodic
1006+
inited = false
1007+
finaled = false
1008+
cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, [x ~ 2], initialize=a, finalize=b)
1009+
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
1010+
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
1011+
sol = solve(prob, Tsit5())
1012+
@test inited == true
1013+
@test finaled == true
1014+
@test isapprox(sol[x][3], 0.0, atol=1e-9)
1015+
@test sol[x][4] 2.0
1016+
@test sol[x][5] 1.0
1017+
1018+
1019+
seen = false
1020+
inited = false
1021+
finaled = false
1022+
cb3 = ModelingToolkit.SymbolicDiscreteCallback(1.0, f, initialize=a, finalize=b)
1023+
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
1024+
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
1025+
sol = solve(prob, Tsit5())
1026+
@test seen == true
1027+
@test inited == true
1028+
1029+
#preset
1030+
seen = false
1031+
inited = false
1032+
finaled = false
1033+
cb3 = ModelingToolkit.SymbolicDiscreteCallback([1.0], f, initialize=a, finalize=b)
1034+
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
1035+
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
1036+
sol = solve(prob, Tsit5())
1037+
@test seen == true
1038+
@test inited == true
1039+
@test finaled == true
1040+
1041+
#equational
1042+
seen = false
1043+
inited = false
1044+
finaled = false
1045+
cb3 = ModelingToolkit.SymbolicDiscreteCallback(t == 1.0, f, initialize=a, finalize=b)
1046+
@mtkbuild sys = ODESystem(D(x) ~ -1, t, [x], []; discrete_events = [cb3])
1047+
prob = ODEProblem(sys, [x => 1.0], (0.0, 2), [])
1048+
sol = solve(prob, Tsit5(); tstops=1.0)
1049+
@test seen == true
1050+
@test inited == true
1051+
@test finaled == true
10041052
end
10051053

10061054
@testset "Bump" begin

0 commit comments

Comments
 (0)