Skip to content

Commit f082e1a

Browse files
fix: use improved discrete saving API
1 parent 68f5e73 commit f082e1a

File tree

1 file changed

+31
-39
lines changed

1 file changed

+31
-39
lines changed

src/systems/callbacks.jl

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -716,12 +716,18 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
716716
return generate_callback(cbs[cb_ind], sys; kwargs...)
717717
end
718718

719+
if is_split(sys)
720+
ic = get_index_cache(sys)
721+
else
722+
ic = nothing
723+
end
719724
trigger = compile_condition(
720725
cbs, sys, unknowns(sys), parameters(sys; initial_parameters = true); kwargs...)
721726
affects = []
722727
affect_negs = []
723728
inits = []
724729
finals = []
730+
discrete_save_idxs = Vector{Int}[]
725731
for cb in cbs
726732
affect = compile_affect(cb.affect, cb, sys; default = EMPTY_AFFECT, kwargs...)
727733
push!(affects, affect)
@@ -731,8 +737,15 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
731737
push!(affect_negs, affect_neg)
732738
push!(inits,
733739
compile_affect(
734-
cb.initialize, cb, sys; default = nothing, is_init = true, kwargs...))
740+
cb.initialize, cb, sys; default = nothing, kwargs...))
735741
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing, kwargs...))
742+
743+
if ic !== nothing
744+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
745+
for _ in conditions(cb)
746+
push!(discrete_save_idxs, save_idxs)
747+
end
748+
end
736749
end
737750

738751
# Since there may be different number of conditions and affects,
@@ -758,7 +771,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
758771

759772
return VectorContinuousCallback(
760773
trigger, affect, affect_neg, length(eqs); initialize, finalize,
761-
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg)
774+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
775+
discrete_save_idxs)
762776
end
763777

764778
function generate_callback(cb, sys; kwargs...)
@@ -775,27 +789,33 @@ function generate_callback(cb, sys; kwargs...)
775789
compile_affect(cb.affect_neg, cb, sys; default = EMPTY_AFFECT, kwargs...)
776790
end
777791
init = compile_affect(cb.initialize, cb, sys; default = SciMLBase.INITIALIZE_DEFAULT,
778-
is_init = true, kwargs...)
792+
kwargs...)
779793
final = compile_affect(
780794
cb.finalize, cb, sys; default = SciMLBase.FINALIZE_DEFAULT, kwargs...)
781795

782796
initialize = isnothing(cb.initialize) ? init : ((c, u, t, i) -> init(i))
783797
finalize = isnothing(cb.finalize) ? final : ((c, u, t, i) -> final(i))
784798

799+
discrete_save_idxs = if is_split(sys)
800+
get(get_index_cache(sys).callback_to_clocks, cb, ())
801+
else
802+
()
803+
end
785804
if is_discrete(cb)
786805
if is_timed && conditions(cb) isa AbstractVector
787806
return PresetTimeCallback(trigger, affect; initialize,
788-
finalize, initializealg = cb.reinitializealg)
807+
finalize, initializealg = cb.reinitializealg, discrete_save_idxs)
789808
elseif is_timed
790809
return PeriodicCallback(
791-
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg)
810+
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg,
811+
discrete_save_idxs)
792812
else
793813
return DiscreteCallback(trigger, affect; initialize,
794-
finalize, initializealg = cb.reinitializealg)
814+
finalize, initializealg = cb.reinitializealg, discrete_save_idxs)
795815
end
796816
else
797817
return ContinuousCallback(trigger, affect, affect_neg; initialize, finalize,
798-
rootfind = cb.rootfind, initializealg = cb.reinitializealg)
818+
rootfind = cb.rootfind, initializealg = cb.reinitializealg, discrete_save_idxs)
799819
end
800820
end
801821

@@ -810,41 +830,13 @@ Notes
810830
"""
811831
function compile_affect(
812832
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem;
813-
default = nothing, is_init = false, kwargs...)
814-
save_idxs = if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
815-
Int[]
816-
else
817-
get(ic.callback_to_clocks, cb, Int[])
818-
end
819-
833+
default = nothing, kwargs...)
820834
if isnothing(aff)
821-
is_init ? wrap_save_discretes(default, save_idxs) : default
835+
default
822836
elseif aff isa AffectSystem
823-
f = compile_equational_affect(aff, sys; kwargs...)
824-
wrap_save_discretes(f, save_idxs)
837+
compile_equational_affect(aff, sys; kwargs...)
825838
elseif aff isa ImperativeAffect
826-
f = compile_functional_affect(aff, sys; kwargs...)
827-
wrap_save_discretes(f, save_idxs)
828-
end
829-
end
830-
831-
function wrap_save_discretes(f, save_idxs)
832-
let save_idxs = save_idxs, f = f
833-
if f === SciMLBase.INITIALIZE_DEFAULT
834-
(c, u, t, i) -> begin
835-
f(c, u, t, i)
836-
for idx in save_idxs
837-
SciMLBase.save_discretes!(i, idx)
838-
end
839-
end
840-
else
841-
(i) -> begin
842-
isnothing(f) || f(i)
843-
for idx in save_idxs
844-
SciMLBase.save_discretes!(i, idx)
845-
end
846-
end
847-
end
839+
compile_functional_affect(aff, sys; kwargs...)
848840
end
849841
end
850842

0 commit comments

Comments
 (0)