Skip to content

Commit e35c935

Browse files
fix: use improved discrete saving API
1 parent f51778d commit e35c935

File tree

1 file changed

+31
-39
lines changed

1 file changed

+31
-39
lines changed

src/systems/callbacks.jl

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -697,12 +697,18 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
697697
return generate_callback(cbs[cb_ind], sys; kwargs...)
698698
end
699699

700+
if is_split(sys)
701+
ic = get_index_cache(sys)
702+
else
703+
ic = nothing
704+
end
700705
trigger = compile_condition(
701706
cbs, sys, unknowns(sys), parameters(sys; initial_parameters = true); kwargs...)
702707
affects = []
703708
affect_negs = []
704709
inits = []
705710
finals = []
711+
discrete_save_idxs = Vector{Int}[]
706712
for cb in cbs
707713
affect = compile_affect(cb.affect, cb, sys; default = EMPTY_AFFECT, kwargs...)
708714
push!(affects, affect)
@@ -712,8 +718,15 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
712718
push!(affect_negs, affect_neg)
713719
push!(inits,
714720
compile_affect(
715-
cb.initialize, cb, sys; default = nothing, is_init = true, kwargs...))
721+
cb.initialize, cb, sys; default = nothing, kwargs...))
716722
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing, kwargs...))
723+
724+
if ic !== nothing
725+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
726+
for _ in conditions(cb)
727+
push!(discrete_save_idxs, save_idxs)
728+
end
729+
end
717730
end
718731

719732
# Since there may be different number of conditions and affects,
@@ -739,7 +752,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
739752

740753
return VectorContinuousCallback(
741754
trigger, affect, affect_neg, length(eqs); initialize, finalize,
742-
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg)
755+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
756+
discrete_save_idxs)
743757
end
744758

745759
function generate_callback(cb, sys; kwargs...)
@@ -756,27 +770,33 @@ function generate_callback(cb, sys; kwargs...)
756770
compile_affect(cb.affect_neg, cb, sys; default = EMPTY_AFFECT, kwargs...)
757771
end
758772
init = compile_affect(cb.initialize, cb, sys; default = SciMLBase.INITIALIZE_DEFAULT,
759-
is_init = true, kwargs...)
773+
kwargs...)
760774
final = compile_affect(
761775
cb.finalize, cb, sys; default = SciMLBase.FINALIZE_DEFAULT, kwargs...)
762776

763777
initialize = isnothing(cb.initialize) ? init : ((c, u, t, i) -> init(i))
764778
finalize = isnothing(cb.finalize) ? final : ((c, u, t, i) -> final(i))
765779

780+
discrete_save_idxs = if is_split(sys)
781+
get(get_index_cache(sys).callback_to_clocks, cb, ())
782+
else
783+
()
784+
end
766785
if is_discrete(cb)
767786
if is_timed && conditions(cb) isa AbstractVector
768787
return PresetTimeCallback(trigger, affect; initialize,
769-
finalize, initializealg = cb.reinitializealg)
788+
finalize, initializealg = cb.reinitializealg, discrete_save_idxs)
770789
elseif is_timed
771790
return PeriodicCallback(
772-
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg)
791+
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg,
792+
discrete_save_idxs)
773793
else
774794
return DiscreteCallback(trigger, affect; initialize,
775-
finalize, initializealg = cb.reinitializealg)
795+
finalize, initializealg = cb.reinitializealg, discrete_save_idxs)
776796
end
777797
else
778798
return ContinuousCallback(trigger, affect, affect_neg; initialize, finalize,
779-
rootfind = cb.rootfind, initializealg = cb.reinitializealg)
799+
rootfind = cb.rootfind, initializealg = cb.reinitializealg, discrete_save_idxs)
780800
end
781801
end
782802

@@ -791,41 +811,13 @@ Notes
791811
"""
792812
function compile_affect(
793813
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem;
794-
default = nothing, is_init = false, kwargs...)
795-
save_idxs = if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
796-
Int[]
797-
else
798-
get(ic.callback_to_clocks, cb, Int[])
799-
end
800-
814+
default = nothing, kwargs...)
801815
if isnothing(aff)
802-
is_init ? wrap_save_discretes(default, save_idxs) : default
816+
default
803817
elseif aff isa AffectSystem
804-
f = compile_equational_affect(aff, sys; kwargs...)
805-
wrap_save_discretes(f, save_idxs)
818+
compile_equational_affect(aff, sys; kwargs...)
806819
elseif aff isa ImperativeAffect
807-
f = compile_functional_affect(aff, sys; kwargs...)
808-
wrap_save_discretes(f, save_idxs)
809-
end
810-
end
811-
812-
function wrap_save_discretes(f, save_idxs)
813-
let save_idxs = save_idxs, f = f
814-
if f === SciMLBase.INITIALIZE_DEFAULT
815-
(c, u, t, i) -> begin
816-
f(c, u, t, i)
817-
for idx in save_idxs
818-
SciMLBase.save_discretes!(i, idx)
819-
end
820-
end
821-
else
822-
(i) -> begin
823-
isnothing(f) || f(i)
824-
for idx in save_idxs
825-
SciMLBase.save_discretes!(i, idx)
826-
end
827-
end
828-
end
820+
compile_functional_affect(aff, sys; kwargs...)
829821
end
830822
end
831823

0 commit comments

Comments
 (0)