@@ -216,6 +216,11 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
216
216
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
217
217
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
218
218
* A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details.
219
+
220
+ Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
221
+ This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
222
+ that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
223
+ initialization.
219
224
"""
220
225
struct SymbolicContinuousCallback
221
226
eqs:: Vector{Equation}
@@ -224,14 +229,16 @@ struct SymbolicContinuousCallback
224
229
affect:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
225
230
affect_neg:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
226
231
rootfind:: SciMLBase.RootfindOpt
232
+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
227
233
function SymbolicContinuousCallback (;
228
234
eqs:: Vector{Equation} ,
229
235
affect = NULL_AFFECT,
230
236
affect_neg = affect,
231
237
rootfind = SciMLBase. LeftRootFind,
232
238
initialize= NULL_AFFECT,
233
- finalize= NULL_AFFECT)
234
- new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind)
239
+ finalize= NULL_AFFECT,
240
+ reinitializealg= SciMLBase. CheckInit ())
241
+ new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
235
242
end # Default affect to nothing
236
243
end
237
244
make_affect (affect) = affect
@@ -373,6 +380,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
373
380
mapreduce (finalize_affects, vcat, cbs, init = Equation[])
374
381
end
375
382
383
+ reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
384
+ reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} ) =
385
+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
386
+
376
387
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
377
388
namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
378
389
namespace_affects (af:: MutatingFunctionalAffect , s) = namespace_affect (af, s)
@@ -419,11 +430,12 @@ struct SymbolicDiscreteCallback
419
430
# TODO : Iterative
420
431
condition:: Any
421
432
affects:: Any
433
+ reinitializealg:: SciMLBase.DAEInitializationAlgorithm
422
434
423
- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT)
435
+ function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT, reinitializealg = SciMLBase . CheckInit () )
424
436
c = scalarize_condition (condition)
425
437
a = scalarize_affects (affects)
426
- new (c, a)
438
+ new (c, a, reinitializealg )
427
439
end # Default affect to nothing
428
440
end
429
441
@@ -481,6 +493,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
481
493
reduce (vcat, affects (cb) for cb in cbs; init = [])
482
494
end
483
495
496
+ reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
497
+ reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} ) =
498
+ mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
499
+
484
500
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
485
501
af = affects (cb)
486
502
af = af isa AbstractVector ? namespace_affect .(af, Ref (s)) : namespace_affect (af, s)
@@ -776,12 +792,13 @@ function generate_single_rootfinding_callback(
776
792
return ContinuousCallback (
777
793
cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
778
794
initialize = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function. initialize (i),
779
- finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i))
795
+ finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i),
796
+ initializealg = reinitialization_alg (cb))
780
797
end
781
798
782
799
function generate_vector_rootfinding_callback (
783
800
cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
784
- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
801
+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind, reinitialization = SciMLBase . CheckInit (), kwargs... )
785
802
eqs = map (cb -> flatten_equations (cb. eqs), cbs)
786
803
num_eqs = length .(eqs)
787
804
# fuse equations to create VectorContinuousCallback
@@ -847,7 +864,7 @@ function generate_vector_rootfinding_callback(
847
864
initialize = handle_optional_setup_fn (map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
848
865
finalize = handle_optional_setup_fn (map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
849
866
return VectorContinuousCallback (
850
- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize)
867
+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization )
851
868
end
852
869
853
870
"""
@@ -893,18 +910,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
893
910
# group the cbs by what rootfind op they use
894
911
# groupby would be very useful here, but alas
895
912
cb_classes = Dict{
896
- @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
913
+ @NamedTuple {
914
+ rootfind:: SciMLBase.RootfindOpt ,
915
+ reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
897
916
for cb in cbs
898
917
push! (
899
- get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
918
+ get! (() -> SymbolicContinuousCallback[], cb_classes, (
919
+ rootfind = cb. rootfind,
920
+ reinitialization = reinitialization_alg (cb))),
900
921
cb)
901
922
end
902
923
903
924
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
904
925
compiled_callbacks = map (collect (pairs (sort! (
905
926
OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
906
927
return generate_vector_rootfinding_callback (
907
- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
928
+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, reinitialization = equiv_class . reinitialization, kwargs... )
908
929
end
909
930
if length (compiled_callbacks) == 1
910
931
return compiled_callbacks[]
0 commit comments