@@ -116,29 +116,33 @@ function ImperativeAffect(f::Function;
116116 modified:: NamedTuple = NamedTuple {()} (()),
117117 ctx = nothing ,
118118 skip_checks = false )
119- ImperativeAffect (f,
119+ ImperativeAffect (f,
120120 collect (values (observed)), collect (keys (observed)),
121- collect (values (modified)), collect (keys (modified)),
121+ collect (values (modified)), collect (keys (modified)),
122122 ctx, skip_checks)
123123end
124124function ImperativeAffect (f:: Function , modified:: NamedTuple ;
125- observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks= false )
126- ImperativeAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
125+ observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks = false )
126+ ImperativeAffect (
127+ f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
127128end
128129function ImperativeAffect (
129- f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing , skip_checks= false )
130- ImperativeAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
130+ f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing , skip_checks = false )
131+ ImperativeAffect (
132+ f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
131133end
132134function ImperativeAffect (
133- f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx; skip_checks= false )
134- ImperativeAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
135+ f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx; skip_checks = false )
136+ ImperativeAffect (
137+ f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
135138end
136139
137- function Base. show (io:: IO , mfa:: ImperativeAffect )
138- obs_vals = join (map ((ob,nm) -> " $ob => $nm " , mfa. obs, mfa. obs_syms), " , " )
139- mod_vals = join (map ((md,nm) -> " $md => $nm " , mfa. modified, mfa. mod_syms), " , " )
140+ function Base. show (io:: IO , mfa:: ImperativeAffect )
141+ obs_vals = join (map ((ob, nm) -> " $ob => $nm " , mfa. obs, mfa. obs_syms), " , " )
142+ mod_vals = join (map ((md, nm) -> " $md => $nm " , mfa. modified, mfa. mod_syms), " , " )
140143 affect = mfa. f
141- print (io, " ImperativeAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
144+ print (io,
145+ " ImperativeAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
142146end
143147func (f:: ImperativeAffect ) = f. f
144148context (a:: ImperativeAffect ) = a. ctx
@@ -234,24 +238,25 @@ struct SymbolicContinuousCallback
234238 affect_neg:: Union{Vector{Equation}, FunctionalAffect, ImperativeAffect, Nothing}
235239 rootfind:: SciMLBase.RootfindOpt
236240 reinitializealg:: SciMLBase.DAEInitializationAlgorithm
237- function SymbolicContinuousCallback (;
238- eqs:: Vector{Equation} ,
239- affect = NULL_AFFECT,
240- affect_neg = affect,
241- rootfind = SciMLBase. LeftRootFind,
242- initialize= NULL_AFFECT,
243- finalize= NULL_AFFECT,
244- reinitializealg= SciMLBase. CheckInit ())
245- new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
241+ function SymbolicContinuousCallback (;
242+ eqs:: Vector{Equation} ,
243+ affect = NULL_AFFECT,
244+ affect_neg = affect,
245+ rootfind = SciMLBase. LeftRootFind,
246+ initialize = NULL_AFFECT,
247+ finalize = NULL_AFFECT,
248+ reinitializealg = SciMLBase. CheckInit ())
249+ new (eqs, initialize, finalize, make_affect (affect),
250+ make_affect (affect_neg), rootfind, reinitializealg)
246251 end # Default affect to nothing
247252end
248253make_affect (affect) = affect
249254make_affect (affect:: Tuple ) = FunctionalAffect (affect... )
250255make_affect (affect:: NamedTuple ) = FunctionalAffect (; affect... )
251256
252257function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
253- isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
254- isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
258+ isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
259+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
255260 isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
256261end
257262Base. isempty (cb:: SymbolicContinuousCallback ) = isempty (cb. eqs)
@@ -266,10 +271,9 @@ function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
266271 hash (cb. rootfind, s)
267272end
268273
269-
270274function Base. show (io:: IO , cb:: SymbolicContinuousCallback )
271275 indent = get (io, :indent , 0 )
272- iio = IOContext (io, :indent => indent+ 1 )
276+ iio = IOContext (io, :indent => indent + 1 )
273277 print (io, " SymbolicContinuousCallback(" )
274278 print (iio, " Equations:" )
275279 show (iio, equations (cb))
298302
299303function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: SymbolicContinuousCallback )
300304 indent = get (io, :indent , 0 )
301- iio = IOContext (io, :indent => indent+ 1 )
305+ iio = IOContext (io, :indent => indent + 1 )
302306 println (io, " SymbolicContinuousCallback:" )
303307 println (iio, " Equations:" )
304308 show (iio, mime, equations (cb))
@@ -338,14 +342,18 @@ end # wrap eq in vector
338342SymbolicContinuousCallback (p:: Pair ) = SymbolicContinuousCallback (p[1 ], p[2 ])
339343SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
340344function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
341- affect_neg = affect, rootfind = SciMLBase. LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
345+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind,
346+ initialize = NULL_AFFECT, finalize = NULL_AFFECT)
342347 SymbolicContinuousCallback (
343- eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize= initialize, finalize= finalize)
348+ eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind,
349+ initialize = initialize, finalize = finalize)
344350end
345351function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
346- affect_neg = affect, rootfind = SciMLBase. LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
352+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind,
353+ initialize = NULL_AFFECT, finalize = NULL_AFFECT)
347354 SymbolicContinuousCallback (
348- eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize= initialize, finalize= finalize)
355+ eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind,
356+ initialize = initialize, finalize = finalize)
349357end
350358
351359SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
@@ -385,8 +393,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
385393end
386394
387395reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
388- reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} ) =
389- mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
396+ function reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} )
397+ mapreduce (
398+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
399+ end
390400
391401namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
392402namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
@@ -436,7 +446,8 @@ struct SymbolicDiscreteCallback
436446 affects:: Any
437447 reinitializealg:: SciMLBase.DAEInitializationAlgorithm
438448
439- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT, reinitializealg= SciMLBase. CheckInit ())
449+ function SymbolicDiscreteCallback (
450+ condition, affects = NULL_AFFECT, reinitializealg = SciMLBase. CheckInit ())
440451 c = scalarize_condition (condition)
441452 a = scalarize_affects (affects)
442453 new (c, a, reinitializealg)
@@ -498,8 +509,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
498509end
499510
500511reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
501- reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} ) =
502- mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
512+ function reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} )
513+ mapreduce (
514+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
515+ end
503516
504517function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
505518 af = affects (cb)
@@ -781,7 +794,8 @@ function generate_single_rootfinding_callback(
781794 end
782795 end
783796
784- user_initfun = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function. initialize (i)
797+ user_initfun = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT :
798+ (c, u, t, i) -> affect_function. initialize (i)
785799 if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
786800 (save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
787801 initfn = let save_idxs = save_idxs
@@ -795,17 +809,19 @@ function generate_single_rootfinding_callback(
795809 else
796810 initfn = user_initfun
797811 end
798-
812+
799813 return ContinuousCallback (
800- cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
801- initialize = initfn,
802- finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i),
814+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
815+ initialize = initfn,
816+ finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT :
817+ (c, u, t, i) -> affect_function. finalize (i),
803818 initializealg = reinitialization_alg (cb))
804819end
805820
806821function generate_vector_rootfinding_callback (
807822 cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
808- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, reinitialization = SciMLBase. CheckInit (), kwargs... )
823+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind,
824+ reinitialization = SciMLBase. CheckInit (), kwargs... )
809825 eqs = map (cb -> flatten_equations (cb. eqs), cbs)
810826 num_eqs = length .(eqs)
811827 # fuse equations to create VectorContinuousCallback
@@ -821,11 +837,12 @@ function generate_vector_rootfinding_callback(
821837 sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
822838
823839 affect_functions = @NamedTuple {
824- affect:: Function ,
825- affect_neg:: Union{Function, Nothing} ,
826- initialize:: Union{Function, Nothing} ,
840+ affect:: Function ,
841+ affect_neg:: Union{Function, Nothing} ,
842+ initialize:: Union{Function, Nothing} ,
827843 finalize:: Union{Function, Nothing} }[
828- compile_affect_fn (cb, sys, dvs, ps, kwargs) for cb in cbs]
844+ compile_affect_fn (cb, sys, dvs, ps, kwargs)
845+ for cb in cbs]
829846 cond = function (out, u, t, integ)
830847 rf_ip (out, u, parameter_values (integ), t)
831848 end
@@ -861,17 +878,20 @@ function generate_vector_rootfinding_callback(
861878 if isnothing (func)
862879 continue
863880 else
864- func (integ)
881+ func (integ)
865882 end
866883 end
867884 end
868885 end
869886 end
870887 end
871- initialize = handle_optional_setup_fn (map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
872- finalize = handle_optional_setup_fn (map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
888+ initialize = handle_optional_setup_fn (
889+ map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
890+ finalize = handle_optional_setup_fn (
891+ map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
873892 return VectorContinuousCallback (
874- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization)
893+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize,
894+ finalize = finalize, initializealg = reinitialization)
875895end
876896
877897"""
@@ -881,8 +901,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
881901 eq_aff = affects (cb)
882902 eq_neg_aff = affect_negs (cb)
883903 affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
884- function compile_optional_affect (aff, default= nothing )
885- if isnothing (aff) || aff== default
904+ function compile_optional_affect (aff, default = nothing )
905+ if isnothing (aff) || aff == default
886906 return nothing
887907 else
888908 return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
@@ -918,21 +938,23 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
918938 # groupby would be very useful here, but alas
919939 cb_classes = Dict{
920940 @NamedTuple {
921- rootfind:: SciMLBase.RootfindOpt ,
941+ rootfind:: SciMLBase.RootfindOpt ,
922942 reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
923943 for cb in cbs
924944 push! (
925- get! (() -> SymbolicContinuousCallback[], cb_classes, (
926- rootfind = cb. rootfind,
927- reinitialization = reinitialization_alg (cb))),
945+ get! (() -> SymbolicContinuousCallback[], cb_classes,
946+ (
947+ rootfind = cb. rootfind,
948+ reinitialization = reinitialization_alg (cb))),
928949 cb)
929950 end
930951
931952 # generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
932953 compiled_callbacks = map (collect (pairs (sort! (
933954 OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
934955 return generate_vector_rootfinding_callback (
935- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, reinitialization= equiv_class. reinitialization, kwargs... )
956+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind,
957+ reinitialization = equiv_class. reinitialization, kwargs... )
936958 end
937959 if length (compiled_callbacks) == 1
938960 return compiled_callbacks[]
@@ -984,29 +1006,34 @@ function invalid_variables(sys, expr)
9841006 filter (x -> ! any (isequal (x), all_symbols (sys)), reduce (vcat, vars (expr); init = []))
9851007end
9861008function unassignable_variables (sys, expr)
987- assignable_syms = reduce (vcat, Symbolics. scalarize .(vcat (unknowns (sys), parameters (sys))); init= [])
1009+ assignable_syms = reduce (
1010+ vcat, Symbolics. scalarize .(vcat (unknowns (sys), parameters (sys))); init = [])
9881011 written = reduce (vcat, Symbolics. scalarize .(vars (expr)); init = [])
9891012 return filter (
9901013 x -> ! any (isequal (x), assignable_syms), written)
9911014end
9921015
993- @generated function _generated_writeback (integ, setters:: NamedTuple{NS1,<:Tuple} , values:: NamedTuple{NS2, <:Tuple} ) where {NS1, NS2}
1016+ @generated function _generated_writeback (integ, setters:: NamedTuple{NS1, <:Tuple} ,
1017+ values:: NamedTuple{NS2, <:Tuple} ) where {NS1, NS2}
9941018 setter_exprs = []
995- for name in NS2
1019+ for name in NS2
9961020 if ! (name in NS1)
9971021 missing_name = " Tried to write back to $name from affect; only declared states ($NS1 ) may be written to."
9981022 error (missing_name)
9991023 end
10001024 push! (setter_exprs, :(setters.$ name (integ, values.$ name)))
10011025 end
1002- return :(begin $ (setter_exprs... ) end )
1026+ return :(begin
1027+ $ (setter_exprs... )
1028+ end )
10031029end
10041030
10051031function check_assignable (sys, sym)
10061032 if symbolic_type (sym) == ScalarSymbolic ()
10071033 is_variable (sys, sym) || is_parameter (sys, sym)
10081034 elseif symbolic_type (sym) == ArraySymbolic ()
1009- is_variable (sys, sym) || is_parameter (sys, sym) || all (x -> check_assignable (sys, x), collect (sym))
1035+ is_variable (sys, sym) || is_parameter (sys, sym) ||
1036+ all (x -> check_assignable (sys, x), collect (sym))
10101037 elseif sym isa Union{AbstractArray, Tuple}
10111038 all (x -> check_assignable (sys, x), sym)
10121039 else
@@ -1084,13 +1111,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
10841111
10851112 # okay so now to generate the stuff to assign it back into the system
10861113 mod_pairs = mod_exprs .=> mod_syms
1087- mod_names = (mod_syms... , )
1114+ mod_names = (mod_syms... ,)
10881115 mod_og_val_fun = build_explicit_observed_function (
10891116 sys, Symbolics. scalarize .(first .(mod_pairs));
10901117 array_type = :tuple )
10911118
10921119 upd_funs = NamedTuple {mod_names} ((setu .((sys,), first .(mod_pairs))... ,))
1093-
1120+
10941121 if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
10951122 save_idxs = get (ic. callback_to_clocks, cb, Int[])
10961123 else
@@ -1104,10 +1131,12 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
11041131 upd_component_array = NamedTuple {mod_names} (modvals)
11051132
11061133 # update the observed values
1107- obs_component_array = NamedTuple {obs_sym_tuple} (obs_fun (integ. u, integ. p, integ. t))
1134+ obs_component_array = NamedTuple {obs_sym_tuple} (obs_fun (
1135+ integ. u, integ. p, integ. t))
11081136
11091137 # let the user do their thing
1110- modvals = if applicable (user_affect, upd_component_array, obs_component_array, ctx, integ)
1138+ modvals = if applicable (
1139+ user_affect, upd_component_array, obs_component_array, ctx, integ)
11111140 user_affect (upd_component_array, obs_component_array, ctx, integ)
11121141 elseif applicable (user_affect, upd_component_array, obs_component_array, ctx)
11131142 user_affect (upd_component_array, obs_component_array, ctx)
@@ -1122,15 +1151,16 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
11221151
11231152 # write the new values back to the integrator
11241153 _generated_writeback (integ, upd_funs, modvals)
1125-
1154+
11261155 for idx in save_idxs
11271156 SciMLBase. save_discretes! (integ, idx)
11281157 end
11291158 end
11301159 end
11311160end
11321161
1133- function compile_affect (affect:: Union{FunctionalAffect, ImperativeAffect} , cb, sys, dvs, ps; kwargs... )
1162+ function compile_affect (
1163+ affect:: Union{FunctionalAffect, ImperativeAffect} , cb, sys, dvs, ps; kwargs... )
11341164 compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
11351165end
11361166
0 commit comments