@@ -106,16 +106,25 @@ The affect function updates the value at `x` in `modified` to be the result of e
106
106
ctx:: Any
107
107
end
108
108
109
- MutatingFunctionalAffect (f:: Function ;
110
- observed:: NamedTuple = NamedTuple {()} (()),
111
- modified:: NamedTuple = NamedTuple {()} (()),
112
- ctx= nothing ) = MutatingFunctionalAffect (f, collect (values (observed)), collect (keys (observed)), collect (values (modified)), collect (keys (modified)), ctx)
113
- MutatingFunctionalAffect (f:: Function , modified:: NamedTuple ; observed:: NamedTuple = NamedTuple {()} (()), ctx= nothing ) =
114
- MutatingFunctionalAffect (f, observed= observed, modified= modified, ctx= ctx)
115
- MutatingFunctionalAffect (f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx= nothing ) =
116
- MutatingFunctionalAffect (f, observed= observed, modified= modified, ctx= ctx)
117
- MutatingFunctionalAffect (f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx) =
118
- MutatingFunctionalAffect (f, observed= observed, modified= modified, ctx= ctx)
109
+ function MutatingFunctionalAffect (f:: Function ;
110
+ observed:: NamedTuple = NamedTuple {()} (()),
111
+ modified:: NamedTuple = NamedTuple {()} (()),
112
+ ctx = nothing )
113
+ MutatingFunctionalAffect (f, collect (values (observed)), collect (keys (observed)),
114
+ collect (values (modified)), collect (keys (modified)), ctx)
115
+ end
116
+ function MutatingFunctionalAffect (f:: Function , modified:: NamedTuple ;
117
+ observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing )
118
+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
119
+ end
120
+ function MutatingFunctionalAffect (
121
+ f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing )
122
+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
123
+ end
124
+ function MutatingFunctionalAffect (
125
+ f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx)
126
+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
127
+ end
119
128
120
129
func (f:: MutatingFunctionalAffect ) = f. f
121
130
context (a:: MutatingFunctionalAffect ) = a. ctx
@@ -126,8 +135,9 @@ modified(a::MutatingFunctionalAffect) = a.modified
126
135
modified_syms (a:: MutatingFunctionalAffect ) = a. mod_syms
127
136
128
137
function Base.:(== )(a1:: MutatingFunctionalAffect , a2:: MutatingFunctionalAffect )
129
- isequal (a1. f, a2. f) && isequal (a1. obs, a2. obs) && isequal (a1. modified, a2. modified) &&
130
- isequal (a1. obs_syms, a2. obs_syms) && isequal (a1. mod_syms, a2. mod_syms)&& isequal (a1. ctx, a2. ctx)
138
+ isequal (a1. f, a2. f) && isequal (a1. obs, a2. obs) && isequal (a1. modified, a2. modified) &&
139
+ isequal (a1. obs_syms, a2. obs_syms) && isequal (a1. mod_syms, a2. mod_syms) &&
140
+ isequal (a1. ctx, a2. ctx)
131
141
end
132
142
133
143
function Base. hash (a:: MutatingFunctionalAffect , s:: UInt )
@@ -839,10 +849,13 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs.
839
849
end
840
850
end
841
851
842
- invalid_variables (sys, expr) = filter (x -> ! any (isequal (x), all_symbols (sys)), reduce (vcat, vars (expr); init= []))
843
- function unassignable_variables (sys, expr)
852
+ function invalid_variables (sys, expr)
853
+ filter (x -> ! any (isequal (x), all_symbols (sys)), reduce (vcat, vars (expr); init = []))
854
+ end
855
+ function unassignable_variables (sys, expr)
844
856
assignable_syms = vcat (unknowns (sys), parameters (sys))
845
- return filter (x -> ! any (isequal (x), assignable_syms), reduce (vcat, vars (expr); init= []))
857
+ return filter (
858
+ x -> ! any (isequal (x), assignable_syms), reduce (vcat, vars (expr); init = []))
846
859
end
847
860
848
861
function compile_user_affect (affect:: MutatingFunctionalAffect , sys, dvs, ps; kwargs... )
@@ -855,7 +868,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
855
868
=#
856
869
function check_dups (syms, exprs) # = (syms_dedup, exprs_dedup)
857
870
seen = Set {Symbol} ()
858
- syms_dedup = []; exprs_dedup = []
871
+ syms_dedup = []
872
+ exprs_dedup = []
859
873
for (sym, exp) in Iterators. zip (syms, exprs)
860
874
if ! in (sym, seen)
861
875
push! (syms_dedup, sym)
@@ -869,7 +883,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
869
883
end
870
884
871
885
obs_exprs = observed (affect)
872
- for oexpr in obs_exprs
886
+ for oexpr in obs_exprs
873
887
invalid_vars = invalid_variables (sys, oexpr)
874
888
if length (invalid_vars) > 0
875
889
error (" Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing)." )
@@ -880,7 +894,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
880
894
obs_size = size .(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
881
895
882
896
mod_exprs = modified (affect)
883
- for mexpr in mod_exprs
897
+ for mexpr in mod_exprs
884
898
if ! is_observed (sys, mexpr) && parameter_index (sys, mexpr) === nothing
885
899
error (" Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect." )
886
900
end
@@ -891,37 +905,50 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
891
905
end
892
906
mod_syms = modified_syms (affect)
893
907
mod_syms, mod_exprs = check_dups (mod_syms, mod_exprs)
894
- _, mod_og_val_fun = build_explicit_observed_function (sys, mod_exprs; return_inplace= true )
908
+ _, mod_og_val_fun = build_explicit_observed_function (
909
+ sys, mod_exprs; return_inplace = true )
895
910
896
911
overlapping_syms = intersect (mod_syms, obs_syms)
897
912
if length (overlapping_syms) > 0
898
913
@warn " The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
899
914
end
900
915
901
916
# sanity checks done! now build the data and update function for observed values
902
- mkzero (sz) = if sz === () 0.0 else zeros (sz) end
903
- _, obs_fun = build_explicit_observed_function (sys, reduce (vcat, Symbolics. scalarize .(obs_exprs); init = []); return_inplace= true )
904
- obs_component_array = ComponentArrays. ComponentArray (NamedTuple {(obs_syms..., )} (mkzero .(obs_size)))
917
+ mkzero (sz) =
918
+ if sz === ()
919
+ 0.0
920
+ else
921
+ zeros (sz)
922
+ end
923
+ _, obs_fun = build_explicit_observed_function (
924
+ sys, reduce (vcat, Symbolics. scalarize .(obs_exprs); init = []);
925
+ return_inplace = true )
926
+ obs_component_array = ComponentArrays. ComponentArray (NamedTuple {(obs_syms...,)} (mkzero .(obs_size)))
905
927
906
928
# okay so now to generate the stuff to assign it back into the system
907
929
# note that we reorder the componentarray to make the views coherent wrt the base array
908
930
mod_pairs = mod_exprs .=> mod_syms
909
931
mod_param_pairs = filter (v -> is_parameter (sys, v[1 ]), mod_pairs)
910
932
mod_unk_pairs = filter (v -> ! is_parameter (sys, v[1 ]), mod_pairs)
911
- _, mod_og_val_fun = build_explicit_observed_function (sys, reduce (vcat, [first .(mod_param_pairs); first .(mod_unk_pairs)]; init = []); return_inplace= true )
912
- upd_params_fun = setu (sys, reduce (vcat, Symbolics. scalarize .(first .(mod_param_pairs)); init = []))
913
- upd_unk_fun = setu (sys, reduce (vcat, Symbolics. scalarize .(first .(mod_unk_pairs)); init = []))
914
-
915
- upd_component_array = ComponentArrays. ComponentArray (NamedTuple {([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)} (
916
- [collect (mkzero (size (e)) for e in first .(mod_param_pairs));
933
+ _, mod_og_val_fun = build_explicit_observed_function (
934
+ sys, reduce (vcat, [first .(mod_param_pairs); first .(mod_unk_pairs)]; init = []);
935
+ return_inplace = true )
936
+ upd_params_fun = setu (
937
+ sys, reduce (vcat, Symbolics. scalarize .(first .(mod_param_pairs)); init = []))
938
+ upd_unk_fun = setu (
939
+ sys, reduce (vcat, Symbolics. scalarize .(first .(mod_unk_pairs)); init = []))
940
+
941
+ upd_component_array = ComponentArrays. ComponentArray (NamedTuple{([last .(mod_param_pairs);
942
+ last .(mod_unk_pairs)]. .. ,)}(
943
+ [collect (mkzero (size (e)) for e in first .(mod_param_pairs));
917
944
collect (mkzero (size (e)) for e in first .(mod_unk_pairs))]))
918
945
upd_params_view = view (upd_component_array, last .(mod_param_pairs))
919
946
upd_unks_view = view (upd_component_array, last .(mod_unk_pairs))
920
947
let user_affect = func (affect), ctx = context (affect)
921
948
function (integ)
922
949
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
923
950
mod_og_val_fun (upd_component_array, integ. u, integ. p... , integ. t)
924
-
951
+
925
952
# update the observed values
926
953
obs_fun (obs_component_array, integ. u, integ. p... , integ. t)
927
954
@@ -934,7 +961,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
934
961
user_affect (upd_component_array, obs_component_array)
935
962
elseif applicable (user_affect, upd_component_array)
936
963
user_affect (upd_component_array)
937
- else
964
+ else
938
965
@error " User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details"
939
966
user_affect (upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message
940
967
end
0 commit comments