@@ -328,13 +328,16 @@ struct SymbolicDiscreteCallback
328328 # TODO : Iterative
329329 condition:: Any
330330 affects:: Any
331+ initialize:: Any
332+ finalize:: Any
331333 reinitializealg:: SciMLBase.DAEInitializationAlgorithm
332334
333335 function SymbolicDiscreteCallback (
334- condition, affects = NULL_AFFECT, reinitializealg = SciMLBase. CheckInit ())
336+ condition, affects = NULL_AFFECT; reinitializealg = SciMLBase. CheckInit (),
337+ initialize= NULL_AFFECT, finalize= NULL_AFFECT)
335338 c = scalarize_condition (condition)
336339 a = scalarize_affects (affects)
337- new (c, a, reinitializealg)
340+ new (c, a, scalarize_affects (initialize), scalarize_affects (finalize), reinitializealg)
338341 end # Default affect to nothing
339342end
340343
@@ -373,11 +376,16 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback)
373376end
374377
375378function Base.:(== )(e1:: SymbolicDiscreteCallback , e2:: SymbolicDiscreteCallback )
376- isequal (e1. condition, e2. condition) && isequal (e1. affects, e2. affects)
379+ isequal (e1. condition, e2. condition) && isequal (e1. affects, e2. affects) &&
380+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize)
377381end
378382function Base. hash (cb:: SymbolicDiscreteCallback , s:: UInt )
379383 s = hash (cb. condition, s)
380- cb. affects isa AbstractVector ? foldr (hash, cb. affects, init = s) : hash (cb. affects, s)
384+ s = cb. affects isa AbstractVector ? foldr (hash, cb. affects, init = s) : hash (cb. affects, s)
385+ s = cb. initialize isa AbstractVector ? foldr (hash, cb. initialize, init = s) : hash (cb. initialize, s)
386+ s = cb. finalize isa AbstractVector ? foldr (hash, cb. finalize, init = s) : hash (cb. finalize, s)
387+ s = hash (cb. reinitializealg, s)
388+ return s
381389end
382390
383391condition (cb:: SymbolicDiscreteCallback ) = cb. condition
@@ -397,10 +405,23 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
397405 reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
398406end
399407
408+
409+ initialize_affects (cb:: SymbolicDiscreteCallback ) = cb. initialize
410+ function initialize_affects (cbs:: Vector{SymbolicDiscreteCallback} )
411+ mapreduce (initialize_affects, vcat, cbs, init = Equation[])
412+ end
413+
414+ finalize_affects (cb:: SymbolicDiscreteCallback ) = cb. finalize
415+ function finalize_affects (cbs:: Vector{SymbolicDiscreteCallback} )
416+ mapreduce (finalize_affects, vcat, cbs, init = Equation[])
417+ end
418+
400419function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
401- af = affects (cb)
402- af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
403- SymbolicDiscreteCallback (namespace_condition (condition (cb), s), af)
420+ function namespace_affects (af)
421+ return af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
422+ end
423+ SymbolicDiscreteCallback (namespace_condition (condition (cb), s), namespace_affects (affects (cb)),
424+ reinitializealg= cb. reinitializealg, initialize= namespace_affects (initialize_affects (cb)), finalize= namespace_affects (finalize_affects (cb)))
404425end
405426
406427SymbolicDiscreteCallbacks (cb:: Pair ) = SymbolicDiscreteCallback[SymbolicDiscreteCallback (cb)]
@@ -773,10 +794,10 @@ function generate_vector_rootfinding_callback(
773794 let save_idxs = save_idxs
774795 if ! isnothing (fn. initialize)
775796 (i) -> begin
797+ fn. initialize (i)
776798 for idx in save_idxs
777799 SciMLBase. save_discretes! (i, idx)
778800 end
779- fn. initialize (i)
780801 end
781802 else
782803 (i) -> begin
@@ -809,20 +830,13 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
809830 eq_aff = affects (cb)
810831 eq_neg_aff = affect_negs (cb)
811832 affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
812- function compile_optional_affect (aff, default = nothing )
813- if isnothing (aff) || aff == default
814- return nothing
815- else
816- return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
817- end
818- end
819833 if eq_neg_aff === eq_aff
820834 affect_neg = affect
821835 else
822- affect_neg = compile_optional_affect ( eq_neg_aff)
836+ affect_neg = _compile_optional_affect (NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs ... )
823837 end
824- initialize = compile_optional_affect ( initialize_affects (cb), NULL_AFFECT )
825- finalize = compile_optional_affect ( finalize_affects (cb), NULL_AFFECT )
838+ initialize = _compile_optional_affect (NULL_AFFECT, initialize_affects (cb), cb, sys, dvs, ps; kwargs ... )
839+ finalize = _compile_optional_affect (NULL_AFFECT, finalize_affects (cb), cb, sys, dvs, ps; kwargs ... )
826840 (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
827841end
828842
@@ -914,31 +928,48 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
914928 compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
915929end
916930
931+
932+ function _compile_optional_affect (default, aff, cb, sys, dvs, ps; kwargs... )
933+ if isnothing (aff) || aff == default
934+ return nothing
935+ else
936+ return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
937+ end
938+ end
917939function generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! = nothing ,
918940 kwargs... )
919941 cond = condition (cb)
920942 as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
921943 postprocess_affect_expr!, kwargs... )
944+
945+ user_initfun = _compile_optional_affect (NULL_AFFECT, initialize_affects (cb), cb, sys, dvs, ps; kwargs... )
946+ user_finfun = _compile_optional_affect (NULL_AFFECT, finalize_affects (cb), cb, sys, dvs, ps; kwargs... )
922947 if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
923948 (save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
924- initfn = let save_idxs = save_idxs
949+ initfn = let
950+ save_idxs = save_idxs
951+ initfun= user_initfun
925952 function (cb, u, t, integrator)
953+ if ! isnothing (initfun)
954+ initfun (integrator)
955+ end
926956 for idx in save_idxs
927957 SciMLBase. save_discretes! (integrator, idx)
928958 end
929959 end
930960 end
931961 else
932- initfn = SciMLBase. INITIALIZE_DEFAULT
962+ initfn = isnothing (user_initfun) ? SciMLBase. INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun (i)
933963 end
964+ finfun = isnothing (user_finfun) ? SciMLBase. FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun (i)
934965 if cond isa AbstractVector
935966 # Preset Time
936967 return PresetTimeCallback (
937- cond, as; initialize = initfn, initializealg = reinitialization_alg (cb))
968+ cond, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg (cb))
938969 else
939970 # Periodic
940971 return PeriodicCallback (
941- as, cond; initialize = initfn, initializealg = reinitialization_alg (cb))
972+ as, cond; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg (cb))
942973 end
943974end
944975
@@ -951,20 +982,27 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
951982 c = compile_condition (cb, sys, dvs, ps; expression = Val{false }, kwargs... )
952983 as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
953984 postprocess_affect_expr!, kwargs... )
985+
986+ user_initfun = _compile_optional_affect (NULL_AFFECT, initialize_affects (cb), cb, sys, dvs, ps; kwargs... )
987+ user_finfun = _compile_optional_affect (NULL_AFFECT, finalize_affects (cb), cb, sys, dvs, ps; kwargs... )
954988 if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
955989 (save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
956- initfn = let save_idxs = save_idxs
990+ initfn = let save_idxs = save_idxs, initfun = user_initfun
957991 function (cb, u, t, integrator)
992+ if ! isnothing (initfun)
993+ initfun (integrator)
994+ end
958995 for idx in save_idxs
959996 SciMLBase. save_discretes! (integrator, idx)
960997 end
961998 end
962999 end
9631000 else
964- initfn = SciMLBase. INITIALIZE_DEFAULT
1001+ initfn = isnothing (user_initfun) ? SciMLBase. INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun (i)
9651002 end
1003+ finfun = isnothing (user_finfun) ? SciMLBase. FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun (i)
9661004 return DiscreteCallback (
967- c, as; initialize = initfn, initializealg = reinitialization_alg (cb))
1005+ c, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg (cb))
9681006 end
9691007end
9701008
0 commit comments