@@ -328,13 +328,16 @@ struct SymbolicDiscreteCallback
328
328
# TODO : Iterative
329
329
condition:: Any
330
330
affects:: Any
331
+ initialize:: Any
332
+ finalize:: Any
331
333
reinitializealg:: SciMLBase.DAEInitializationAlgorithm
332
334
333
335
function SymbolicDiscreteCallback (
334
- condition, affects = NULL_AFFECT, reinitializealg = SciMLBase. CheckInit ())
336
+ condition, affects = NULL_AFFECT; reinitializealg = SciMLBase. CheckInit (),
337
+ initialize= NULL_AFFECT, finalize= NULL_AFFECT)
335
338
c = scalarize_condition (condition)
336
339
a = scalarize_affects (affects)
337
- new (c, a, reinitializealg)
340
+ new (c, a, scalarize_affects (initialize), scalarize_affects (finalize), reinitializealg)
338
341
end # Default affect to nothing
339
342
end
340
343
@@ -373,11 +376,16 @@ function Base.show(io::IO, db::SymbolicDiscreteCallback)
373
376
end
374
377
375
378
function Base.:(== )(e1:: SymbolicDiscreteCallback , e2:: SymbolicDiscreteCallback )
376
- isequal (e1. condition, e2. condition) && isequal (e1. affects, e2. affects)
379
+ isequal (e1. condition, e2. condition) && isequal (e1. affects, e2. affects) &&
380
+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize)
377
381
end
378
382
function Base. hash (cb:: SymbolicDiscreteCallback , s:: UInt )
379
383
s = hash (cb. condition, s)
380
- cb. affects isa AbstractVector ? foldr (hash, cb. affects, init = s) : hash (cb. affects, s)
384
+ s = cb. affects isa AbstractVector ? foldr (hash, cb. affects, init = s) : hash (cb. affects, s)
385
+ s = cb. initialize isa AbstractVector ? foldr (hash, cb. initialize, init = s) : hash (cb. initialize, s)
386
+ s = cb. finalize isa AbstractVector ? foldr (hash, cb. finalize, init = s) : hash (cb. finalize, s)
387
+ s = hash (cb. reinitializealg, s)
388
+ return s
381
389
end
382
390
383
391
condition (cb:: SymbolicDiscreteCallback ) = cb. condition
@@ -397,10 +405,23 @@ function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
397
405
reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
398
406
end
399
407
408
+
409
+ initialize_affects (cb:: SymbolicDiscreteCallback ) = cb. initialize
410
+ function initialize_affects (cbs:: Vector{SymbolicDiscreteCallback} )
411
+ mapreduce (initialize_affects, vcat, cbs, init = Equation[])
412
+ end
413
+
414
+ finalize_affects (cb:: SymbolicDiscreteCallback ) = cb. finalize
415
+ function finalize_affects (cbs:: Vector{SymbolicDiscreteCallback} )
416
+ mapreduce (finalize_affects, vcat, cbs, init = Equation[])
417
+ end
418
+
400
419
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
401
- af = affects (cb)
402
- af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
403
- SymbolicDiscreteCallback (namespace_condition (condition (cb), s), af)
420
+ function namespace_affects (af)
421
+ return af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
422
+ end
423
+ SymbolicDiscreteCallback (namespace_condition (condition (cb), s), namespace_affects (affects (cb)),
424
+ reinitializealg= cb. reinitializealg, initialize= namespace_affects (initialize_affects (cb)), finalize= namespace_affects (finalize_affects (cb)))
404
425
end
405
426
406
427
SymbolicDiscreteCallbacks (cb:: Pair ) = SymbolicDiscreteCallback[SymbolicDiscreteCallback (cb)]
@@ -773,10 +794,10 @@ function generate_vector_rootfinding_callback(
773
794
let save_idxs = save_idxs
774
795
if ! isnothing (fn. initialize)
775
796
(i) -> begin
797
+ fn. initialize (i)
776
798
for idx in save_idxs
777
799
SciMLBase. save_discretes! (i, idx)
778
800
end
779
- fn. initialize (i)
780
801
end
781
802
else
782
803
(i) -> begin
@@ -809,20 +830,13 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
809
830
eq_aff = affects (cb)
810
831
eq_neg_aff = affect_negs (cb)
811
832
affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
812
- function compile_optional_affect (aff, default = nothing )
813
- if isnothing (aff) || aff == default
814
- return nothing
815
- else
816
- return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
817
- end
818
- end
819
833
if eq_neg_aff === eq_aff
820
834
affect_neg = affect
821
835
else
822
- affect_neg = compile_optional_affect ( eq_neg_aff)
836
+ affect_neg = _compile_optional_affect (NULL_AFFECT, eq_neg_aff, cb, sys, dvs, ps; kwargs ... )
823
837
end
824
- initialize = compile_optional_affect ( initialize_affects (cb), NULL_AFFECT )
825
- finalize = compile_optional_affect ( finalize_affects (cb), NULL_AFFECT )
838
+ initialize = _compile_optional_affect (NULL_AFFECT, initialize_affects (cb), cb, sys, dvs, ps; kwargs ... )
839
+ finalize = _compile_optional_affect (NULL_AFFECT, finalize_affects (cb), cb, sys, dvs, ps; kwargs ... )
826
840
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
827
841
end
828
842
@@ -914,31 +928,48 @@ function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
914
928
compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
915
929
end
916
930
931
+
932
+ function _compile_optional_affect (default, aff, cb, sys, dvs, ps; kwargs... )
933
+ if isnothing (aff) || aff == default
934
+ return nothing
935
+ else
936
+ return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
937
+ end
938
+ end
917
939
function generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! = nothing ,
918
940
kwargs... )
919
941
cond = condition (cb)
920
942
as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
921
943
postprocess_affect_expr!, kwargs... )
944
+
945
+ user_initfun = _compile_optional_affect (NULL_AFFECT, initialize_affects (cb), cb, sys, dvs, ps; kwargs... )
946
+ user_finfun = _compile_optional_affect (NULL_AFFECT, finalize_affects (cb), cb, sys, dvs, ps; kwargs... )
922
947
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
923
948
(save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
924
- initfn = let save_idxs = save_idxs
949
+ initfn = let
950
+ save_idxs = save_idxs
951
+ initfun= user_initfun
925
952
function (cb, u, t, integrator)
953
+ if ! isnothing (initfun)
954
+ initfun (integrator)
955
+ end
926
956
for idx in save_idxs
927
957
SciMLBase. save_discretes! (integrator, idx)
928
958
end
929
959
end
930
960
end
931
961
else
932
- initfn = SciMLBase. INITIALIZE_DEFAULT
962
+ initfn = isnothing (user_initfun) ? SciMLBase. INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun (i)
933
963
end
964
+ finfun = isnothing (user_finfun) ? SciMLBase. FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun (i)
934
965
if cond isa AbstractVector
935
966
# Preset Time
936
967
return PresetTimeCallback (
937
- cond, as; initialize = initfn, initializealg = reinitialization_alg (cb))
968
+ cond, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg (cb))
938
969
else
939
970
# Periodic
940
971
return PeriodicCallback (
941
- as, cond; initialize = initfn, initializealg = reinitialization_alg (cb))
972
+ as, cond; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg (cb))
942
973
end
943
974
end
944
975
@@ -951,20 +982,27 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
951
982
c = compile_condition (cb, sys, dvs, ps; expression = Val{false }, kwargs... )
952
983
as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
953
984
postprocess_affect_expr!, kwargs... )
985
+
986
+ user_initfun = _compile_optional_affect (NULL_AFFECT, initialize_affects (cb), cb, sys, dvs, ps; kwargs... )
987
+ user_finfun = _compile_optional_affect (NULL_AFFECT, finalize_affects (cb), cb, sys, dvs, ps; kwargs... )
954
988
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
955
989
(save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
956
- initfn = let save_idxs = save_idxs
990
+ initfn = let save_idxs = save_idxs, initfun = user_initfun
957
991
function (cb, u, t, integrator)
992
+ if ! isnothing (initfun)
993
+ initfun (integrator)
994
+ end
958
995
for idx in save_idxs
959
996
SciMLBase. save_discretes! (integrator, idx)
960
997
end
961
998
end
962
999
end
963
1000
else
964
- initfn = SciMLBase. INITIALIZE_DEFAULT
1001
+ initfn = isnothing (user_initfun) ? SciMLBase. INITIALIZE_DEFAULT : (_,_,_,i) -> user_initfun (i)
965
1002
end
1003
+ finfun = isnothing (user_finfun) ? SciMLBase. FINALIZE_DEFAULT : (_,_,_,i) -> user_finfun (i)
966
1004
return DiscreteCallback (
967
- c, as; initialize = initfn, initializealg = reinitialization_alg (cb))
1005
+ c, as; initialize = initfn, finalize = finfun, initializealg = reinitialization_alg (cb))
968
1006
end
969
1007
end
970
1008
0 commit comments