@@ -106,16 +106,25 @@ The affect function updates the value at `x` in `modified` to be the result of e
106106 ctx:: Any
107107end
108108
109- MutatingFunctionalAffect (f:: Function ;
110- observed:: NamedTuple = NamedTuple {()} (()),
111- modified:: NamedTuple = NamedTuple {()} (()),
112- ctx= nothing ) = MutatingFunctionalAffect (f, collect (values (observed)), collect (keys (observed)), collect (values (modified)), collect (keys (modified)), ctx)
113- MutatingFunctionalAffect (f:: Function , modified:: NamedTuple ; observed:: NamedTuple = NamedTuple {()} (()), ctx= nothing ) =
114- MutatingFunctionalAffect (f, observed= observed, modified= modified, ctx= ctx)
115- MutatingFunctionalAffect (f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx= nothing ) =
116- MutatingFunctionalAffect (f, observed= observed, modified= modified, ctx= ctx)
117- MutatingFunctionalAffect (f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx) =
118- MutatingFunctionalAffect (f, observed= observed, modified= modified, ctx= ctx)
109+ function MutatingFunctionalAffect (f:: Function ;
110+ observed:: NamedTuple = NamedTuple {()} (()),
111+ modified:: NamedTuple = NamedTuple {()} (()),
112+ ctx = nothing )
113+ MutatingFunctionalAffect (f, collect (values (observed)), collect (keys (observed)),
114+ collect (values (modified)), collect (keys (modified)), ctx)
115+ end
116+ function MutatingFunctionalAffect (f:: Function , modified:: NamedTuple ;
117+ observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing )
118+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
119+ end
120+ function MutatingFunctionalAffect (
121+ f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing )
122+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
123+ end
124+ function MutatingFunctionalAffect (
125+ f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx)
126+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
127+ end
119128
120129func (f:: MutatingFunctionalAffect ) = f. f
121130context (a:: MutatingFunctionalAffect ) = a. ctx
@@ -126,8 +135,9 @@ modified(a::MutatingFunctionalAffect) = a.modified
126135modified_syms (a:: MutatingFunctionalAffect ) = a. mod_syms
127136
128137function Base.:(== )(a1:: MutatingFunctionalAffect , a2:: MutatingFunctionalAffect )
129- isequal (a1. f, a2. f) && isequal (a1. obs, a2. obs) && isequal (a1. modified, a2. modified) &&
130- isequal (a1. obs_syms, a2. obs_syms) && isequal (a1. mod_syms, a2. mod_syms)&& isequal (a1. ctx, a2. ctx)
138+ isequal (a1. f, a2. f) && isequal (a1. obs, a2. obs) && isequal (a1. modified, a2. modified) &&
139+ isequal (a1. obs_syms, a2. obs_syms) && isequal (a1. mod_syms, a2. mod_syms) &&
140+ isequal (a1. ctx, a2. ctx)
131141end
132142
133143function Base. hash (a:: MutatingFunctionalAffect , s:: UInt )
@@ -237,11 +247,13 @@ SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
237247SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
238248function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
239249 affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
240- SymbolicContinuousCallback (eqs= [eqs], affect= affect, affect_neg= affect_neg, rootfind= rootfind)
250+ SymbolicContinuousCallback (
251+ eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
241252end
242253function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
243254 affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
244- SymbolicContinuousCallback (eqs= eqs, affect= affect, affect_neg= affect_neg, rootfind= rootfind)
255+ SymbolicContinuousCallback (
256+ eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
245257end
246258
247259SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
@@ -765,10 +777,13 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
765777 end
766778end
767779
768- invalid_variables (sys, expr) = filter (x -> ! any (isequal (x), all_symbols (sys)), reduce (vcat, vars (expr); init= []))
769- function unassignable_variables (sys, expr)
780+ function invalid_variables (sys, expr)
781+ filter (x -> ! any (isequal (x), all_symbols (sys)), reduce (vcat, vars (expr); init = []))
782+ end
783+ function unassignable_variables (sys, expr)
770784 assignable_syms = vcat (unknowns (sys), parameters (sys))
771- return filter (x -> ! any (isequal (x), assignable_syms), reduce (vcat, vars (expr); init= []))
785+ return filter (
786+ x -> ! any (isequal (x), assignable_syms), reduce (vcat, vars (expr); init = []))
772787end
773788
774789function compile_user_affect (affect:: MutatingFunctionalAffect , sys, dvs, ps; kwargs... )
@@ -781,7 +796,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
781796 =#
782797 function check_dups (syms, exprs) # = (syms_dedup, exprs_dedup)
783798 seen = Set {Symbol} ()
784- syms_dedup = []; exprs_dedup = []
799+ syms_dedup = []
800+ exprs_dedup = []
785801 for (sym, exp) in Iterators. zip (syms, exprs)
786802 if ! in (sym, seen)
787803 push! (syms_dedup, sym)
@@ -795,7 +811,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
795811 end
796812
797813 obs_exprs = observed (affect)
798- for oexpr in obs_exprs
814+ for oexpr in obs_exprs
799815 invalid_vars = invalid_variables (sys, oexpr)
800816 if length (invalid_vars) > 0
801817 error (" Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing)." )
@@ -806,7 +822,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
806822 obs_size = size .(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
807823
808824 mod_exprs = modified (affect)
809- for mexpr in mod_exprs
825+ for mexpr in mod_exprs
810826 if ! is_observed (sys, mexpr) && parameter_index (sys, mexpr) === nothing
811827 error (" Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect." )
812828 end
@@ -817,37 +833,50 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
817833 end
818834 mod_syms = modified_syms (affect)
819835 mod_syms, mod_exprs = check_dups (mod_syms, mod_exprs)
820- _, mod_og_val_fun = build_explicit_observed_function (sys, mod_exprs; return_inplace= true )
836+ _, mod_og_val_fun = build_explicit_observed_function (
837+ sys, mod_exprs; return_inplace = true )
821838
822839 overlapping_syms = intersect (mod_syms, obs_syms)
823840 if length (overlapping_syms) > 0
824841 @warn " The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
825842 end
826843
827844 # sanity checks done! now build the data and update function for observed values
828- mkzero (sz) = if sz === () 0.0 else zeros (sz) end
829- _, obs_fun = build_explicit_observed_function (sys, reduce (vcat, Symbolics. scalarize .(obs_exprs); init = []); return_inplace= true )
830- obs_component_array = ComponentArrays. ComponentArray (NamedTuple {(obs_syms..., )} (mkzero .(obs_size)))
845+ mkzero (sz) =
846+ if sz === ()
847+ 0.0
848+ else
849+ zeros (sz)
850+ end
851+ _, obs_fun = build_explicit_observed_function (
852+ sys, reduce (vcat, Symbolics. scalarize .(obs_exprs); init = []);
853+ return_inplace = true )
854+ obs_component_array = ComponentArrays. ComponentArray (NamedTuple {(obs_syms...,)} (mkzero .(obs_size)))
831855
832856 # okay so now to generate the stuff to assign it back into the system
833857 # note that we reorder the componentarray to make the views coherent wrt the base array
834858 mod_pairs = mod_exprs .=> mod_syms
835859 mod_param_pairs = filter (v -> is_parameter (sys, v[1 ]), mod_pairs)
836860 mod_unk_pairs = filter (v -> ! is_parameter (sys, v[1 ]), mod_pairs)
837- _, mod_og_val_fun = build_explicit_observed_function (sys, reduce (vcat, [first .(mod_param_pairs); first .(mod_unk_pairs)]; init = []); return_inplace= true )
838- upd_params_fun = setu (sys, reduce (vcat, Symbolics. scalarize .(first .(mod_param_pairs)); init = []))
839- upd_unk_fun = setu (sys, reduce (vcat, Symbolics. scalarize .(first .(mod_unk_pairs)); init = []))
840-
841- upd_component_array = ComponentArrays. ComponentArray (NamedTuple {([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)} (
842- [collect (mkzero (size (e)) for e in first .(mod_param_pairs));
861+ _, mod_og_val_fun = build_explicit_observed_function (
862+ sys, reduce (vcat, [first .(mod_param_pairs); first .(mod_unk_pairs)]; init = []);
863+ return_inplace = true )
864+ upd_params_fun = setu (
865+ sys, reduce (vcat, Symbolics. scalarize .(first .(mod_param_pairs)); init = []))
866+ upd_unk_fun = setu (
867+ sys, reduce (vcat, Symbolics. scalarize .(first .(mod_unk_pairs)); init = []))
868+
869+ upd_component_array = ComponentArrays. ComponentArray (NamedTuple{([last .(mod_param_pairs);
870+ last .(mod_unk_pairs)]. .. ,)}(
871+ [collect (mkzero (size (e)) for e in first .(mod_param_pairs));
843872 collect (mkzero (size (e)) for e in first .(mod_unk_pairs))]))
844873 upd_params_view = view (upd_component_array, last .(mod_param_pairs))
845874 upd_unks_view = view (upd_component_array, last .(mod_unk_pairs))
846875 let user_affect = func (affect), ctx = context (affect)
847876 function (integ)
848877 # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
849878 mod_og_val_fun (upd_component_array, integ. u, integ. p... , integ. t)
850-
879+
851880 # update the observed values
852881 obs_fun (obs_component_array, integ. u, integ. p... , integ. t)
853882
@@ -860,7 +889,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
860889 user_affect (upd_component_array, obs_component_array)
861890 elseif applicable (user_affect, upd_component_array)
862891 user_affect (upd_component_array)
863- else
892+ else
864893 @error " User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details"
865894 user_affect (upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message
866895 end
0 commit comments