Skip to content

Commit c1b3cbd

Browse files
committed
fix: add reinitalizealg back
1 parent d6df569 commit c1b3cbd

File tree

5 files changed

+53
-30
lines changed

5 files changed

+53
-30
lines changed

ext/MTKFMIExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ with the name `namespace__variable`.
9393
- `name`: The name of the system.
9494
"""
9595
function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
96-
communication_step_size = nothing, type, name) where {Ver}
96+
communication_step_size = nothing, type, name, reinitializealg = nothing) where {Ver}
9797
if Ver != 2 && Ver != 3
9898
throw(ArgumentError("FMI Version must be `2` or `3`"))
9999
end
@@ -238,7 +238,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
238238
finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], [])
239239
step_affect = MTK.FunctionalAffect(Returns(nothing), [], [], [])
240240
instance_management_callback = MTK.SymbolicDiscreteCallback(
241-
(t != t - 1), step_affect; finalize = finalize_affect)
241+
(t != t - 1), step_affect; finalize = finalize_affect, reinitializealg)
242242

243243
push!(params, wrapper)
244244
append!(observed, der_observed)
@@ -279,7 +279,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
279279
fmiCSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor)
280280
instance_management_callback = MTK.SymbolicDiscreteCallback(
281281
communication_step_size, step_affect; initialize = initialize_affect,
282-
finalize = finalize_affect)
282+
finalize = finalize_affect, reinitializealg)
283283

284284
# guarded in case there are no outputs/states and the variable is `[]`.
285285
symbolic_type(__mtk_internal_o) == NotSymbolic() || push!(params, __mtk_internal_o)

src/systems/callbacks.jl

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
200200
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
201201
* A [`ImperativeAffect`](@ref); refer to its documentation for details.
202202
203-
DAEs will automatically be reinitialized.
203+
`reinitializealg` is used to set how the system will be reinitialized after the callback.
204+
- Symbolic affects have reinitialization built in. In this case the algorithm will default to SciMLBase.NoInit(), and should **not** be provided.
205+
- Functional and imperative affects will default to SciMLBase.CheckInit(), which will error if the system is not properly reinitialized after the callback. If your system is a DAE, pass in an algorithm like SciMLBase.BrownBasicFullInit() to properly re-initialize.
204206
205207
Initial and final affects can also be specified identically to positive and negative edge affects. Initialization affects
206208
will run as soon as the solver starts, while finalization affects will be executed after termination.
@@ -212,6 +214,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
212214
initialize::Union{Affect, Nothing}
213215
finalize::Union{Affect, Nothing}
214216
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
217+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
215218

216219
function SymbolicContinuousCallback(
217220
conditions::Union{Equation, Vector{Equation}},
@@ -221,13 +224,21 @@ struct SymbolicContinuousCallback <: AbstractCallback
221224
initialize = nothing,
222225
finalize = nothing,
223226
rootfind = SciMLBase.LeftRootFind,
227+
reinitializealg = nothing,
224228
iv = nothing,
225229
algeeqs = Equation[])
226230
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
231+
232+
if isnothing(reinitializealg)
233+
any(a -> (a isa FunctionalAffect || a isa ImperativeAffect), [affect, affect_neg, initialize, finalize]) ?
234+
reinitializealg = SciMLBase.CheckInit() :
235+
reinitializealg = SciMLBase.NoInit()
236+
end
237+
227238
new(conditions, make_affect(affect; iv, algeeqs, discrete_parameters),
228239
make_affect(affect_neg; iv, algeeqs, discrete_parameters),
229240
make_affect(initialize; iv, algeeqs, discrete_parameters), make_affect(
230-
finalize; iv, algeeqs, discrete_parameters), rootfind)
241+
finalize; iv, algeeqs, discrete_parameters), rootfind, reinitializealg)
231242
end # Default affect to nothing
232243
end
233244

@@ -424,16 +435,22 @@ struct SymbolicDiscreteCallback <: AbstractCallback
424435
affect::Union{Affect, Nothing}
425436
initialize::Union{Affect, Nothing}
426437
finalize::Union{Affect, Nothing}
438+
reinitializealg::SciMLBase.DAEInitializationAlgorithm
427439

428440
function SymbolicDiscreteCallback(
429441
condition, affect = nothing;
430442
initialize = nothing, finalize = nothing, iv = nothing,
431-
algeeqs = Equation[], discrete_parameters = Any[])
443+
algeeqs = Equation[], discrete_parameters = Any[], reinitializealg = nothing)
432444
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
433445

446+
if isnothing(reinitializealg)
447+
any(a -> (a isa FunctionalAffect || a isa ImperativeAffect), [affect, affect_neg, initialize, finalize]) ?
448+
reinitializealg = SciMLBase.CheckInit() :
449+
reinitializealg = SciMLBase.NoInit()
450+
end
434451
new(c, make_affect(affect; iv, algeeqs, discrete_parameters),
435452
make_affect(initialize; iv, algeeqs, discrete_parameters),
436-
make_affect(finalize; iv, algeeqs, discrete_parameters))
453+
make_affect(finalize; iv, algeeqs, discrete_parameters), reinitializealg)
437454
end # Default affect to nothing
438455
end
439456

@@ -525,7 +542,8 @@ function Base.hash(cb::AbstractCallback, s::UInt)
525542
!is_discrete(cb) && (s = hash(affect_negs(cb), s))
526543
s = hash(initialize_affects(cb), s)
527544
s = hash(finalize_affects(cb), s)
528-
!is_discrete(cb) ? hash(cb.rootfind, s) : s
545+
!is_discrete(cb) && (s = hash(cb.rootfind, s))
546+
hash(cb.reinitializealg, s)
529547
end
530548

531549
###########################
@@ -562,7 +580,7 @@ end
562580
function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback)
563581
(is_discrete(e1) === is_discrete(e2)) || return false
564582
(isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) &&
565-
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) ||
583+
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) && isequal(e1.reinitializealg, e2.reinitializealg) ||
566584
return false
567585
is_discrete(e1) ||
568586
(isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind))
@@ -664,15 +682,15 @@ function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
664682
ps = parameters(sys; initial_parameters = true); kwargs...)
665683
cbs = continuous_events(sys)
666684
isempty(cbs) && return nothing
667-
cb_classes = Dict{SciMLBase.RootfindOpt, Vector{SymbolicContinuousCallback}}()
685+
cb_classes = Dict{Tuple{SciMLBase.RootfindOpt, SciMLBase.DAEReinitializationAlg}, Vector{SymbolicContinuousCallback}}()
668686

669687
# Sort the callbacks by their rootfinding method
670688
for cb in cbs
671-
_cbs = get!(() -> SymbolicContinuousCallback[], cb_classes, cb.rootfind)
689+
_cbs = get!(() -> SymbolicContinuousCallback[], cb_classes, (cb.rootfind, cb.reinitializealg))
672690
push!(_cbs, cb)
673691
end
674-
cb_classes = sort!(OrderedDict(cb_classes))
675-
compiled_callbacks = [generate_callback(cb, sys; kwargs...) for (rf, cb) in cb_classes]
692+
sort!(OrderedDict(cb_classes), by = cb -> cb.rootfind)
693+
compiled_callbacks = [generate_callback(cb, sys; kwargs...) for ((rf, reinit), cb) in cb_classes]
676694
if length(compiled_callbacks) == 1
677695
return only(compiled_callbacks)
678696
else
@@ -741,7 +759,7 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
741759

742760
return VectorContinuousCallback(
743761
trigger, affect, affect_neg, length(eqs); initialize, finalize,
744-
rootfind = cbs[1].rootfind, initializealg = SciMLBase.NoInit())
762+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg)
745763
end
746764

747765
function generate_callback(cb, sys; kwargs...)
@@ -768,16 +786,16 @@ function generate_callback(cb, sys; kwargs...)
768786
if is_discrete(cb)
769787
if is_timed && conditions(cb) isa AbstractVector
770788
return PresetTimeCallback(trigger, affect; initialize,
771-
finalize, initializealg = SciMLBase.NoInit())
789+
finalize, initializealg = cb.reinitializealg)
772790
elseif is_timed
773-
return PeriodicCallback(affect, trigger; initialize, finalize, initializealg = SciMLBase.NoInit())
791+
return PeriodicCallback(affect, trigger; initialize, finalize, initializealg = cb.reinitializealg)
774792
else
775793
return DiscreteCallback(trigger, affect; initialize,
776-
finalize, initializealg = SciMLBase.NoInit())
794+
finalize, initializealg = cb.reinitializealg)
777795
end
778796
else
779797
return ContinuousCallback(trigger, affect, affect_neg; initialize, finalize,
780-
rootfind = cb.rootfind, initializealg = SciMLBase.NoInit())
798+
rootfind = cb.rootfind, initializealg = cb.reinitializealg)
781799
end
782800
end
783801

src/systems/model_parsing.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ function _model_macro(mod, name, expr, isconnector)
6464
push!(exprs.args, :(systems = ODESystem[]))
6565
push!(exprs.args, :(equations = Union{Equation, Vector{Equation}}[]))
6666
push!(exprs.args, :(defaults = Dict{Num, Union{Number, Symbol, Function}}()))
67-
push!(exprs.args, :(disc_events = []))
68-
push!(exprs.args, :(cont_events = []))
6967

7068
Base.remove_linenums!(expr)
7169
for arg in expr.args
@@ -107,8 +105,6 @@ function _model_macro(mod, name, expr, isconnector)
107105
push!(exprs.args, :(push!(parameters, $(ps...))))
108106
push!(exprs.args, :(push!(systems, $(comps...))))
109107
push!(exprs.args, :(push!(variables, $(vs...))))
110-
push!(exprs.args, :(push!(disc_events, $(d_evts...))))
111-
push!(exprs.args, :(push!(cont_events, $(c_evts...))))
112108

113109
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
114110
GUIMetadata(GlobalRef(mod, name))
@@ -120,7 +116,7 @@ function _model_macro(mod, name, expr, isconnector)
120116

121117
sys = :($ODESystem($(flatten_equations)(equations), $iv, variables, parameters;
122118
name, description = $description, systems, gui_metadata = $gui_metadata,
123-
defaults, continuous_events = cont_events, discrete_events = disc_events))
119+
defaults))
124120

125121
if length(ext) == 0
126122
push!(exprs.args, :(var"#___sys___" = $sys))
@@ -131,6 +127,18 @@ function _model_macro(mod, name, expr, isconnector)
131127
isconnector && push!(exprs.args,
132128
:($Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))))
133129

130+
!isempty(c_evts) && push!(exprs.args,
131+
:($Setfield.@set!(var"#___sys___".continuous_events=$SymbolicContinuousCallback.([
132+
$(c_evts...)
133+
]))))
134+
135+
@show d_evts
136+
!isempty(d_evts) && push!(exprs.args,
137+
:($Setfield.@set!(var"#___sys___".discrete_events=$SymbolicDiscreteCallback.([
138+
$(d_evts...)
139+
]))))
140+
141+
134142
f = if length(where_types) == 0
135143
:($(Symbol(:__, name, :__))(; name, $(kwargs...)) = $exprs)
136144
else

test/fmi/fmi.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ end
157157
@testset "v2, CS" begin
158158
fmu = loadFMU(joinpath(FMU_DIR, "SimpleAdder.fmu"); type = :CS)
159159
@named adder = MTK.FMIComponent(
160-
Val(2); fmu, type = :CS, communication_step_size = 1e-6)
160+
Val(2); fmu, type = :CS, communication_step_size = 1e-6, reinitializealg = BrownFullBasicInit())
161161
@test MTK.isinput(adder.a)
162162
@test MTK.isinput(adder.b)
163163
@test MTK.isoutput(adder.out)
@@ -209,7 +209,7 @@ end
209209
@testset "v3, CS" begin
210210
fmu = loadFMU(joinpath(FMU_DIR, "StateSpace.fmu"); type = :CS)
211211
@named sspace = MTK.FMIComponent(
212-
Val(3); fmu, communication_step_size = 1e-6, type = :CS)
212+
Val(3); fmu, communication_step_size = 1e-6, type = :CS, reinitializealg = BrownFullBasicInit())
213213
@test MTK.isinput(sspace.u)
214214
@test MTK.isoutput(sspace.y)
215215
@test !MTK.isinput(sspace.x) && !MTK.isoutput(sspace.x)

test/symbolic_events.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,7 @@ end
12041204
@mtkmodel DECAY begin
12051205
@parameters begin
12061206
unrelated[1:2] = zeros(2)
1207-
k = 0.0
1207+
k(t) = 0.0
12081208
end
12091209
@variables begin
12101210
x(t) = 10.0
@@ -1213,7 +1213,7 @@ end
12131213
D(x) ~ -k * x
12141214
end
12151215
@discrete_events begin
1216-
(t == 1.0) => [k ~ 1.0]
1216+
(t == 1.0) => [k ~ 1.0], discrete_parameters = [k]
12171217
end
12181218
end
12191219
@mtkbuild decay = DECAY()
@@ -1338,7 +1338,4 @@ end
13381338
sol = solve(prob, FBDF())
13391339
@test prob.ps[g] == sol.ps[g]
13401340
end
1341-
# TODO: test:
1342-
# - Functional affects reinitialize correctly
13431341
# - explicit equation of t in a functional affect
1344-
# - reinitialization after affects

0 commit comments

Comments
 (0)