Skip to content

Commit f57215a

Browse files
committed
Add support for the initializealg argument in SciMLBase callbacks
1 parent 95fa1ee commit f57215a

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

src/systems/callbacks.jl

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
216216
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
217217
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
218218
* A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details.
219+
220+
Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
221+
This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
222+
that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
223+
initialization.
219224
"""
220225
struct SymbolicContinuousCallback
221226
eqs::Vector{Equation}
@@ -224,14 +229,16 @@ struct SymbolicContinuousCallback
224229
affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
225230
affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
226231
rootfind::SciMLBase.RootfindOpt
232+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
227233
function SymbolicContinuousCallback(;
228234
eqs::Vector{Equation},
229235
affect = NULL_AFFECT,
230236
affect_neg = affect,
231237
rootfind = SciMLBase.LeftRootFind,
232238
initialize=NULL_AFFECT,
233-
finalize=NULL_AFFECT)
234-
new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind)
239+
finalize=NULL_AFFECT,
240+
reinitializealg=SciMLBase.CheckInit())
241+
new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg)
235242
end # Default affect to nothing
236243
end
237244
make_affect(affect) = affect
@@ -373,6 +380,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
373380
mapreduce(finalize_affects, vcat, cbs, init = Equation[])
374381
end
375382

383+
reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg
384+
reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) =
385+
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
386+
376387
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
377388
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
378389
namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s)
@@ -419,11 +430,12 @@ struct SymbolicDiscreteCallback
419430
# TODO: Iterative
420431
condition::Any
421432
affects::Any
433+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
422434

423-
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT)
435+
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT, reinitializealg=SciMLBase.CheckInit())
424436
c = scalarize_condition(condition)
425437
a = scalarize_affects(affects)
426-
new(c, a)
438+
new(c, a, reinitializealg)
427439
end # Default affect to nothing
428440
end
429441

@@ -481,6 +493,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
481493
reduce(vcat, affects(cb) for cb in cbs; init = [])
482494
end
483495

496+
reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg
497+
reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) =
498+
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
499+
484500
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
485501
af = affects(cb)
486502
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
@@ -776,12 +792,13 @@ function generate_single_rootfinding_callback(
776792
return ContinuousCallback(
777793
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
778794
initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i),
779-
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i))
795+
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i),
796+
initializealg = reinitialization_alg(cb))
780797
end
781798

782799
function generate_vector_rootfinding_callback(
783800
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
784-
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
801+
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, reinitialization = SciMLBase.CheckInit(), kwargs...)
785802
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
786803
num_eqs = length.(eqs)
787804
# fuse equations to create VectorContinuousCallback
@@ -847,7 +864,7 @@ function generate_vector_rootfinding_callback(
847864
initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
848865
finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT)
849866
return VectorContinuousCallback(
850-
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize)
867+
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization)
851868
end
852869

853870
"""
@@ -893,18 +910,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
893910
# group the cbs by what rootfind op they use
894911
# groupby would be very useful here, but alas
895912
cb_classes = Dict{
896-
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
913+
@NamedTuple{
914+
rootfind::SciMLBase.RootfindOpt,
915+
reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}()
897916
for cb in cbs
898917
push!(
899-
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
918+
get!(() -> SymbolicContinuousCallback[], cb_classes, (
919+
rootfind = cb.rootfind,
920+
reinitialization = reinitialization_alg(cb))),
900921
cb)
901922
end
902923

903924
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
904925
compiled_callbacks = map(collect(pairs(sort!(
905926
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
906927
return generate_vector_rootfinding_callback(
907-
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
928+
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, reinitialization=equiv_class.reinitialization, kwargs...)
908929
end
909930
if length(compiled_callbacks) == 1
910931
return compiled_callbacks[]

test/symbolic_events.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -996,8 +996,8 @@ end
996996
@test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0]
997997
sol = solve(prob, Tsit5())
998998

999-
@test sol[a] == [1.0, -1.0]
1000-
@test sol[b] == [2.0, 5.0, 5.0]
999+
@test sol[a] == [-1.0]
1000+
@test sol[b] == [5.0, 5.0]
10011001
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
10021002
end
10031003
@testset "Heater" begin
@@ -1198,5 +1198,5 @@ end
11981198
ss = structural_simplify(sys)
11991199
prob = ODEProblem(ss, [theta => 0.0], (0.0, pi))
12001200
sol = solve(prob, Tsit5(); dtmax = 0.01)
1201-
@test sol[cnt] == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state
1201+
@test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state
12021202
end

0 commit comments

Comments
 (0)