diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 86cab57634..c427ae02a1 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -106,15 +106,25 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either: + `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`. + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition. + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem. + +DAEs will be reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`) after callbacks are applied. +This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate +that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of +initialization. """ struct SymbolicContinuousCallback eqs::Vector{Equation} affect::Union{Vector{Equation}, FunctionalAffect} affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing} rootfind::SciMLBase.RootfindOpt - function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT, - affect_neg = affect, rootfind = SciMLBase.LeftRootFind) - new(eqs, make_affect(affect), make_affect(affect_neg), rootfind) + reinitializealg::SciMLBase.DAEInitializationAlgorithm + function SymbolicContinuousCallback(; + eqs::Vector{Equation}, + affect = NULL_AFFECT, + affect_neg = affect, + rootfind = SciMLBase.LeftRootFind, + reinitializealg = SciMLBase.CheckInit()) + new(eqs, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg) end # Default affect to nothing end make_affect(affect) = affect @@ -183,6 +193,12 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback}) mapreduce(affect_negs, vcat, cbs, init = Equation[]) end +reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg +function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) + mapreduce( + reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) +end + namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) namespace_affects(::Nothing, s) = nothing @@ -225,11 +241,13 @@ struct SymbolicDiscreteCallback # TODO: Iterative condition::Any affects::Any + reinitializealg::SciMLBase.DAEInitializationAlgorithm - function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT) + function SymbolicDiscreteCallback( + condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit()) c = scalarize_condition(condition) a = scalarize_affects(affects) - new(c, a) + new(c, a, reinitializealg) end # Default affect to nothing end @@ -286,6 +304,12 @@ function affects(cbs::Vector{SymbolicDiscreteCallback}) reduce(vcat, affects(cb) for cb in cbs; init = []) end +reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg +function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) + mapreduce( + reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) +end + function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback af = affects(cb) af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s) @@ -579,13 +603,15 @@ function generate_single_rootfinding_callback( initfn = SciMLBase.INITIALIZE_DEFAULT end return ContinuousCallback( - cond, affect_function.affect, affect_function.affect_neg, - rootfind = cb.rootfind, initialize = initfn) + cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, + initialize = initfn, + initializealg = reinitialization_alg(cb)) end function generate_vector_rootfinding_callback( cbs, sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...) + ps = parameters(sys); rootfind = SciMLBase.RightRootFind, + reinitialization = SciMLBase.CheckInit(), kwargs...) eqs = map(cb -> flatten_equations(cb.eqs), cbs) num_eqs = length.(eqs) # fuse equations to create VectorContinuousCallback @@ -650,7 +676,8 @@ function generate_vector_rootfinding_callback( initfn = SciMLBase.INITIALIZE_DEFAULT end return VectorContinuousCallback( - cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn) + cond, affect, affect_neg, length(eqs), rootfind = rootfind, + initialize = initfn, initializealg = reinitialization) end """ @@ -690,10 +717,15 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow # group the cbs by what rootfind op they use # groupby would be very useful here, but alas cb_classes = Dict{ - @NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}() + @NamedTuple{ + rootfind::SciMLBase.RootfindOpt, + reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}() for cb in cbs push!( - get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)), + get!(() -> SymbolicContinuousCallback[], cb_classes, + ( + rootfind = cb.rootfind, + reinitialization = reinitialization_alg(cb))), cb) end @@ -701,7 +733,8 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow compiled_callbacks = map(collect(pairs(sort!( OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class) return generate_vector_rootfinding_callback( - cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...) + cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, + reinitialization = equiv_class.reinitialization, kwargs...) end if length(compiled_callbacks) == 1 return compiled_callbacks[] @@ -772,10 +805,12 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no end if cond isa AbstractVector # Preset Time - return PresetTimeCallback(cond, as; initialize = initfn) + return PresetTimeCallback( + cond, as; initialize = initfn, initializealg = reinitialization_alg(cb)) else # Periodic - return PeriodicCallback(as, cond; initialize = initfn) + return PeriodicCallback( + as, cond; initialize = initfn, initializealg = reinitialization_alg(cb)) end end @@ -800,7 +835,8 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = else initfn = SciMLBase.INITIALIZE_DEFAULT end - return DiscreteCallback(c, as; initialize = initfn) + return DiscreteCallback( + c, as; initialize = initfn, initializealg = reinitialization_alg(cb)) end end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index e1d12814ef..26593f980c 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -867,6 +867,88 @@ end @test sign.(cos.(3 * (required_crossings_c2 .+ 1e-6))) == sign.(last.(cr2)) end +@testset "Discrete event reinitialization (#3142)" begin + @connector LiquidPort begin + p(t)::Float64, [description = "Set pressure in bar", + guess = 1.01325] + Vdot(t)::Float64, + [description = "Volume flow rate in L/min", + guess = 0.0, + connect = Flow] + end + + @mtkmodel PressureSource begin + @components begin + port = LiquidPort() + end + @parameters begin + p_set::Float64 = 1.01325, [description = "Set pressure in bar"] + end + @equations begin + port.p ~ p_set + end + end + + @mtkmodel BinaryValve begin + @constants begin + p_ref::Float64 = 1.0, [description = "Reference pressure drop in bar"] + ρ_ref::Float64 = 1000.0, [description = "Reference density in kg/m^3"] + end + @components begin + port_in = LiquidPort() + port_out = LiquidPort() + end + @parameters begin + k_V::Float64 = 1.0, [description = "Valve coefficient in L/min/bar"] + k_leakage::Float64 = 1e-08, [description = "Leakage coefficient in L/min/bar"] + ρ::Float64 = 1000.0, [description = "Density in kg/m^3"] + end + @variables begin + S(t)::Float64, [description = "Valve state", guess = 1.0, irreducible = true] + Δp(t)::Float64, [description = "Pressure difference in bar", guess = 1.0] + Vdot(t)::Float64, [description = "Volume flow rate in L/min", guess = 1.0] + end + @equations begin + # Port handling + port_in.Vdot ~ -Vdot + port_out.Vdot ~ Vdot + Δp ~ port_in.p - port_out.p + # System behavior + D(S) ~ 0.0 + Vdot ~ S * k_V * sign(Δp) * sqrt(abs(Δp) / p_ref * ρ_ref / ρ) + k_leakage * Δp # softplus alpha function to avoid negative values under the sqrt + end + end + + # Test System + @mtkmodel TestSystem begin + @components begin + pressure_source_1 = PressureSource(p_set = 2.0) + binary_valve_1 = BinaryValve(S = 1.0, k_leakage = 0.0) + binary_valve_2 = BinaryValve(S = 1.0, k_leakage = 0.0) + pressure_source_2 = PressureSource(p_set = 1.0) + end + @equations begin + connect(pressure_source_1.port, binary_valve_1.port_in) + connect(binary_valve_1.port_out, binary_valve_2.port_in) + connect(binary_valve_2.port_out, pressure_source_2.port) + end + @discrete_events begin + [30] => [binary_valve_1.S ~ 0.0, binary_valve_2.Δp ~ 0.0] + [60] => [ + binary_valve_1.S ~ 1.0, binary_valve_2.S ~ 0.0, binary_valve_2.Δp ~ 1.0] + [120] => [binary_valve_1.S ~ 0.0, binary_valve_2.Δp ~ 0.0] + end + end + + # Test Simulation + @mtkbuild sys = TestSystem() + + # Test Simulation + prob = ODEProblem(sys, [], (0.0, 150.0)) + sol = solve(prob) + @test sol[end] == [0.0, 0.0, 0.0] +end + @testset "Discrete variable timeseries" begin @variables x(t) @parameters a(t) b(t) c(t) @@ -887,3 +969,22 @@ end @test sol[b] == [2.0, 5.0, 5.0] @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end + +@testset "Bump" begin + @variables x(t) [irreducible = true] y(t) [irreducible = true] + eqs = [x ~ y, D(x) ~ -1] + cb = [x ~ 0.0] => [x ~ 0, y ~ 1] + @mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb]) + prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x]) + @test_throws "CheckInit specified but initialization" solve(prob, Rodas5()) + + cb = [x ~ 0.0] => [y ~ 1] + @mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb]) + prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x]) + @test_broken !SciMLBase.successful_retcode(solve(prob, Rodas5())) + + cb = [x ~ 0.0] => [x ~ 1, y ~ 1] + @mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb]) + prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x]) + @test all(≈(0.0; atol = 1e-9), solve(prob, Rodas5())[[x, y]][end]) +end