Skip to content

Commit bbda412

Browse files
committed
Add support for the initializealg argument in SciMLBase callbacks
1 parent 57dcc7e commit bbda412

File tree

2 files changed

+160
-15
lines changed

2 files changed

+160
-15
lines changed

src/systems/callbacks.jl

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
106106
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
107107
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
108108
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
109+
110+
Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
111+
This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
112+
that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
113+
initialization.
109114
"""
110115
struct SymbolicContinuousCallback
111116
eqs::Vector{Equation}
112117
affect::Union{Vector{Equation}, FunctionalAffect}
113118
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
114119
rootfind::SciMLBase.RootfindOpt
115-
function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT,
116-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
117-
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
120+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
121+
function SymbolicContinuousCallback(;
122+
eqs::Vector{Equation},
123+
affect = NULL_AFFECT,
124+
affect_neg = affect,
125+
rootfind = SciMLBase.LeftRootFind,
126+
reinitializealg=SciMLBase.CheckInit())
127+
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg)
118128
end # Default affect to nothing
119129
end
120130
make_affect(affect) = affect
@@ -183,6 +193,10 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
183193
mapreduce(affect_negs, vcat, cbs, init = Equation[])
184194
end
185195

196+
reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg
197+
reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) =
198+
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
199+
186200
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
187201
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
188202
namespace_affects(::Nothing, s) = nothing
@@ -225,11 +239,12 @@ struct SymbolicDiscreteCallback
225239
# TODO: Iterative
226240
condition::Any
227241
affects::Any
242+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
228243

229-
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT)
244+
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT, reinitializealg=SciMLBase.CheckInit())
230245
c = scalarize_condition(condition)
231246
a = scalarize_affects(affects)
232-
new(c, a)
247+
new(c, a, reinitializealg)
233248
end # Default affect to nothing
234249
end
235250

@@ -286,6 +301,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
286301
reduce(vcat, affects(cb) for cb in cbs; init = [])
287302
end
288303

304+
reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg
305+
reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) =
306+
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
307+
289308
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
290309
af = affects(cb)
291310
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
@@ -579,13 +598,14 @@ function generate_single_rootfinding_callback(
579598
initfn = SciMLBase.INITIALIZE_DEFAULT
580599
end
581600
return ContinuousCallback(
582-
cond, affect_function.affect, affect_function.affect_neg,
583-
rootfind = cb.rootfind, initialize = initfn)
601+
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
602+
initialize = initfn,
603+
initializealg = reinitialization_alg(cb))
584604
end
585605

586606
function generate_vector_rootfinding_callback(
587607
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
588-
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
608+
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, reinitialization = SciMLBase.CheckInit(), kwargs...)
589609
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
590610
num_eqs = length.(eqs)
591611
# fuse equations to create VectorContinuousCallback
@@ -650,7 +670,7 @@ function generate_vector_rootfinding_callback(
650670
initfn = SciMLBase.INITIALIZE_DEFAULT
651671
end
652672
return VectorContinuousCallback(
653-
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn)
673+
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn, initializealg = reinitialization)
654674
end
655675

656676
"""
@@ -690,18 +710,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
690710
# group the cbs by what rootfind op they use
691711
# groupby would be very useful here, but alas
692712
cb_classes = Dict{
693-
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
713+
@NamedTuple{
714+
rootfind::SciMLBase.RootfindOpt,
715+
reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}()
694716
for cb in cbs
695717
push!(
696-
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
718+
get!(() -> SymbolicContinuousCallback[], cb_classes, (
719+
rootfind = cb.rootfind,
720+
reinitialization = reinitialization_alg(cb))),
697721
cb)
698722
end
699723

700724
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
701725
compiled_callbacks = map(collect(pairs(sort!(
702726
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
703727
return generate_vector_rootfinding_callback(
704-
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
728+
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, reinitialization=equiv_class.reinitialization, kwargs...)
705729
end
706730
if length(compiled_callbacks) == 1
707731
return compiled_callbacks[]
@@ -772,10 +796,10 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no
772796
end
773797
if cond isa AbstractVector
774798
# Preset Time
775-
return PresetTimeCallback(cond, as; initialize = initfn)
799+
return PresetTimeCallback(cond, as; initialize = initfn, initializealg=reinitialization_alg(cb))
776800
else
777801
# Periodic
778-
return PeriodicCallback(as, cond; initialize = initfn)
802+
return PeriodicCallback(as, cond; initialize = initfn, initializealg=reinitialization_alg(cb))
779803
end
780804
end
781805

@@ -800,7 +824,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
800824
else
801825
initfn = SciMLBase.INITIALIZE_DEFAULT
802826
end
803-
return DiscreteCallback(c, as; initialize = initfn)
827+
return DiscreteCallback(c, as; initialize = initfn, initializealg=reinitialization_alg(cb))
804828
end
805829
end
806830

test/symbolic_events.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,108 @@ end
867867
@test sign.(cos.(3 * (required_crossings_c2 .+ 1e-6))) == sign.(last.(cr2))
868868
end
869869

870+
@testset "Discrete variable timeseries" begin
871+
@variables x(t)
872+
@parameters a(t) b(t) c(t)
873+
cb1 = [x ~ 1.0] => [a ~ -a]
874+
function save_affect!(integ, u, p, ctx)
875+
integ.ps[p.b] = 5.0
876+
end
877+
cb2 = [x ~ 0.5] => (save_affect!, [], [b], [b], nothing)
878+
cb3 = 1.0 => [c ~ t]
879+
880+
@mtkbuild sys = ODESystem(D(x) ~ cos(t), t, [x], [a, b, c];
881+
continuous_events = [cb1, cb2], discrete_events = [cb3])
882+
prob = ODEProblem(sys, [x => 1.0], (0.0, 2pi), [a => 1.0, b => 2.0, c => 0.0])
883+
@test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0]
884+
sol = solve(prob, Tsit5())
885+
886+
@test sol[a] == [-1.0]
887+
@test sol[b] == [5.0, 5.0]
888+
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
889+
end
890+
891+
@testset "Discrete event reinitialization (#3142)" begin
892+
@connector LiquidPort begin
893+
p(t)::Float64, [ description = "Set pressure in bar",
894+
guess = 1.01325]
895+
Vdot(t)::Float64, [ description = "Volume flow rate in L/min",
896+
guess = 0.0,
897+
connect = Flow]
898+
end
899+
900+
@mtkmodel PressureSource begin
901+
@components begin
902+
port = LiquidPort()
903+
end
904+
@parameters begin
905+
p_set::Float64 = 1.01325, [description = "Set pressure in bar"]
906+
end
907+
@equations begin
908+
port.p ~ p_set
909+
end
910+
end
911+
912+
@mtkmodel BinaryValve begin
913+
@constants begin
914+
p_ref::Float64 = 1.0, [description = "Reference pressure drop in bar"]
915+
ρ_ref::Float64 = 1000.0, [description = "Reference density in kg/m^3"]
916+
end
917+
@components begin
918+
port_in = LiquidPort()
919+
port_out = LiquidPort()
920+
end
921+
@parameters begin
922+
k_V::Float64 = 1.0, [description = "Valve coefficient in L/min/bar"]
923+
k_leakage::Float64 = 1e-08, [description = "Leakage coefficient in L/min/bar"]
924+
ρ::Float64 = 1000.0, [description = "Density in kg/m^3"]
925+
end
926+
@variables begin
927+
S(t)::Float64, [description = "Valve state", guess = 1.0, irreducible = true]
928+
Δp(t)::Float64, [description = "Pressure difference in bar", guess = 1.0]
929+
Vdot(t)::Float64, [description = "Volume flow rate in L/min", guess = 1.0]
930+
end
931+
@equations begin
932+
# Port handling
933+
port_in.Vdot ~ -Vdot
934+
port_out.Vdot ~ Vdot
935+
Δp ~ port_in.p - port_out.p
936+
# System behavior
937+
D(S) ~ 0.0
938+
Vdot ~ S*k_V*sign(Δp)*sqrt(abs(Δp)/p_ref * ρ_ref/ρ) + k_leakage*Δp # softplus alpha function to avoid negative values under the sqrt
939+
end
940+
end
941+
942+
# Test System
943+
@mtkmodel TestSystem begin
944+
@components begin
945+
pressure_source_1 = PressureSource(p_set = 2.0)
946+
binary_valve_1 = BinaryValve(S = 1.0, k_leakage=0.0)
947+
binary_valve_2 = BinaryValve(S = 1.0, k_leakage=0.0)
948+
pressure_source_2 = PressureSource(p_set = 1.0)
949+
end
950+
@equations begin
951+
connect(pressure_source_1.port, binary_valve_1.port_in)
952+
connect(binary_valve_1.port_out, binary_valve_2.port_in)
953+
connect(binary_valve_2.port_out, pressure_source_2.port)
954+
end
955+
@discrete_events begin
956+
[30] => [binary_valve_1.S ~ 0.0, binary_valve_2.Δp ~ 0.0 ]
957+
[60] => [binary_valve_1.S ~ 1.0, binary_valve_2.S ~ 0.0, binary_valve_2.Δp ~ 1.0 ]
958+
[120] => [binary_valve_1.S ~ 0.0, binary_valve_2.Δp ~ 0.0 ]
959+
end
960+
end
961+
962+
# Test Simulation
963+
@mtkbuild sys = TestSystem()
964+
965+
# Test Simulation
966+
prob = ODEProblem(sys, [], (0.0, 150.0))
967+
sol = solve(prob)
968+
@test sol[end] == [0.0, 0.0, 0.0]
969+
end
970+
971+
870972
@testset "Discrete variable timeseries" begin
871973
@variables x(t)
872974
@parameters a(t) b(t) c(t)
@@ -887,3 +989,22 @@ end
887989
@test sol[b] == [2.0, 5.0, 5.0]
888990
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
889991
end
992+
993+
@testset "Bump" begin
994+
@variables x(t) [irreducible=true] y(t) [irreducible=true]
995+
eqs = [x ~ y, D(x) ~ -1]
996+
cb = [x ~ 0.0] => [x ~ 0, y ~ 1]
997+
@mtkbuild pend = ODESystem(eqs, t;continuous_events = [cb])
998+
prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x])
999+
@test_throws "initialization not satisifed" solve(prob, Rodas5())
1000+
1001+
cb = [x ~ 0.0] => [y ~ 1]
1002+
@mtkbuild pend = ODESystem(eqs, t;continuous_events = [cb])
1003+
prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x])
1004+
@test_broken !SciMLBase.successful_retcode(solve(prob, Rodas5()))
1005+
1006+
cb = [x ~ 0.0] => [x ~ 1, y ~ 1]
1007+
@mtkbuild pend = ODESystem(eqs, t;continuous_events = [cb])
1008+
prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x])
1009+
@test all((0.0; atol=1e-9), solve(prob, Rodas5())[[x,y]][end])
1010+
end

0 commit comments

Comments
 (0)