Skip to content

Commit 9948de0

Browse files
committed
Run formatter
1 parent fd0125d commit 9948de0

File tree

2 files changed

+158
-104
lines changed

2 files changed

+158
-104
lines changed

src/systems/callbacks.jl

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,25 @@ The affect function updates the value at `x` in `modified` to be the result of e
106106
ctx::Any
107107
end
108108

109-
MutatingFunctionalAffect(f::Function;
110-
observed::NamedTuple = NamedTuple{()}(()),
111-
modified::NamedTuple = NamedTuple{()}(()),
112-
ctx=nothing) = MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx)
113-
MutatingFunctionalAffect(f::Function, modified::NamedTuple; observed::NamedTuple = NamedTuple{()}(()), ctx=nothing) =
114-
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
115-
MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple; ctx=nothing) =
116-
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
117-
MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple, ctx) =
118-
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
109+
function MutatingFunctionalAffect(f::Function;
110+
observed::NamedTuple = NamedTuple{()}(()),
111+
modified::NamedTuple = NamedTuple{()}(()),
112+
ctx = nothing)
113+
MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)),
114+
collect(values(modified)), collect(keys(modified)), ctx)
115+
end
116+
function MutatingFunctionalAffect(f::Function, modified::NamedTuple;
117+
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing)
118+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
119+
end
120+
function MutatingFunctionalAffect(
121+
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing)
122+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
123+
end
124+
function MutatingFunctionalAffect(
125+
f::Function, modified::NamedTuple, observed::NamedTuple, ctx)
126+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
127+
end
119128

120129
func(f::MutatingFunctionalAffect) = f.f
121130
context(a::MutatingFunctionalAffect) = a.ctx
@@ -126,8 +135,9 @@ modified(a::MutatingFunctionalAffect) = a.modified
126135
modified_syms(a::MutatingFunctionalAffect) = a.mod_syms
127136

128137
function Base.:(==)(a1::MutatingFunctionalAffect, a2::MutatingFunctionalAffect)
129-
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
130-
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms)&& isequal(a1.ctx, a2.ctx)
138+
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
139+
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
140+
isequal(a1.ctx, a2.ctx)
131141
end
132142

133143
function Base.hash(a::MutatingFunctionalAffect, s::UInt)
@@ -839,10 +849,13 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs.
839849
end
840850
end
841851

842-
invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init=[]))
843-
function unassignable_variables(sys, expr)
852+
function invalid_variables(sys, expr)
853+
filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = []))
854+
end
855+
function unassignable_variables(sys, expr)
844856
assignable_syms = vcat(unknowns(sys), parameters(sys))
845-
return filter(x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init=[]))
857+
return filter(
858+
x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init = []))
846859
end
847860

848861
function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...)
@@ -855,7 +868,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
855868
=#
856869
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
857870
seen = Set{Symbol}()
858-
syms_dedup = []; exprs_dedup = []
871+
syms_dedup = []
872+
exprs_dedup = []
859873
for (sym, exp) in Iterators.zip(syms, exprs)
860874
if !in(sym, seen)
861875
push!(syms_dedup, sym)
@@ -869,7 +883,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
869883
end
870884

871885
obs_exprs = observed(affect)
872-
for oexpr in obs_exprs
886+
for oexpr in obs_exprs
873887
invalid_vars = invalid_variables(sys, oexpr)
874888
if length(invalid_vars) > 0
875889
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
@@ -880,7 +894,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
880894
obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
881895

882896
mod_exprs = modified(affect)
883-
for mexpr in mod_exprs
897+
for mexpr in mod_exprs
884898
if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing
885899
error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
886900
end
@@ -891,37 +905,50 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
891905
end
892906
mod_syms = modified_syms(affect)
893907
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
894-
_, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true)
908+
_, mod_og_val_fun = build_explicit_observed_function(
909+
sys, mod_exprs; return_inplace = true)
895910

896911
overlapping_syms = intersect(mod_syms, obs_syms)
897912
if length(overlapping_syms) > 0
898913
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
899914
end
900915

901916
# sanity checks done! now build the data and update function for observed values
902-
mkzero(sz) = if sz === () 0.0 else zeros(sz) end
903-
_, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true)
904-
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms..., )}(mkzero.(obs_size)))
917+
mkzero(sz) =
918+
if sz === ()
919+
0.0
920+
else
921+
zeros(sz)
922+
end
923+
_, obs_fun = build_explicit_observed_function(
924+
sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []);
925+
return_inplace = true)
926+
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms...,)}(mkzero.(obs_size)))
905927

906928
# okay so now to generate the stuff to assign it back into the system
907929
# note that we reorder the componentarray to make the views coherent wrt the base array
908930
mod_pairs = mod_exprs .=> mod_syms
909931
mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs)
910932
mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs)
911-
_, mod_og_val_fun = build_explicit_observed_function(sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); return_inplace=true)
912-
upd_params_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
913-
upd_unk_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))
914-
915-
upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)}(
916-
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
933+
_, mod_og_val_fun = build_explicit_observed_function(
934+
sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []);
935+
return_inplace = true)
936+
upd_params_fun = setu(
937+
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
938+
upd_unk_fun = setu(
939+
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))
940+
941+
upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs);
942+
last.(mod_unk_pairs)]...,)}(
943+
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
917944
collect(mkzero(size(e)) for e in first.(mod_unk_pairs))]))
918945
upd_params_view = view(upd_component_array, last.(mod_param_pairs))
919946
upd_unks_view = view(upd_component_array, last.(mod_unk_pairs))
920947
let user_affect = func(affect), ctx = context(affect)
921948
function (integ)
922949
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
923950
mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t)
924-
951+
925952
# update the observed values
926953
obs_fun(obs_component_array, integ.u, integ.p..., integ.t)
927954

@@ -934,7 +961,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
934961
user_affect(upd_component_array, obs_component_array)
935962
elseif applicable(user_affect, upd_component_array)
936963
user_affect(upd_component_array)
937-
else
964+
else
938965
@error "User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details"
939966
user_affect(upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message
940967
end

0 commit comments

Comments
 (0)