Skip to content

Commit d6c0cd9

Browse files
committed
Run formatter
1 parent 8413edd commit d6c0cd9

File tree

2 files changed

+164
-108
lines changed

2 files changed

+164
-108
lines changed

src/systems/callbacks.jl

Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,25 @@ The affect function updates the value at `x` in `modified` to be the result of e
106106
ctx::Any
107107
end
108108

109-
MutatingFunctionalAffect(f::Function;
110-
observed::NamedTuple = NamedTuple{()}(()),
111-
modified::NamedTuple = NamedTuple{()}(()),
112-
ctx=nothing) = MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx)
113-
MutatingFunctionalAffect(f::Function, modified::NamedTuple; observed::NamedTuple = NamedTuple{()}(()), ctx=nothing) =
114-
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
115-
MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple; ctx=nothing) =
116-
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
117-
MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple, ctx) =
118-
MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx)
109+
function MutatingFunctionalAffect(f::Function;
110+
observed::NamedTuple = NamedTuple{()}(()),
111+
modified::NamedTuple = NamedTuple{()}(()),
112+
ctx = nothing)
113+
MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)),
114+
collect(values(modified)), collect(keys(modified)), ctx)
115+
end
116+
function MutatingFunctionalAffect(f::Function, modified::NamedTuple;
117+
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing)
118+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
119+
end
120+
function MutatingFunctionalAffect(
121+
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing)
122+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
123+
end
124+
function MutatingFunctionalAffect(
125+
f::Function, modified::NamedTuple, observed::NamedTuple, ctx)
126+
MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx)
127+
end
119128

120129
func(f::MutatingFunctionalAffect) = f.f
121130
context(a::MutatingFunctionalAffect) = a.ctx
@@ -126,8 +135,9 @@ modified(a::MutatingFunctionalAffect) = a.modified
126135
modified_syms(a::MutatingFunctionalAffect) = a.mod_syms
127136

128137
function Base.:(==)(a1::MutatingFunctionalAffect, a2::MutatingFunctionalAffect)
129-
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
130-
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms)&& isequal(a1.ctx, a2.ctx)
138+
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
139+
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
140+
isequal(a1.ctx, a2.ctx)
131141
end
132142

133143
function Base.hash(a::MutatingFunctionalAffect, s::UInt)
@@ -237,11 +247,13 @@ SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
237247
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
238248
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
239249
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
240-
SymbolicContinuousCallback(eqs=[eqs], affect=affect, affect_neg=affect_neg, rootfind=rootfind)
250+
SymbolicContinuousCallback(
251+
eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
241252
end
242253
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
243254
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
244-
SymbolicContinuousCallback(eqs=eqs, affect=affect, affect_neg=affect_neg, rootfind=rootfind)
255+
SymbolicContinuousCallback(
256+
eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
245257
end
246258

247259
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
@@ -765,10 +777,13 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
765777
end
766778
end
767779

768-
invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init=[]))
769-
function unassignable_variables(sys, expr)
780+
function invalid_variables(sys, expr)
781+
filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = []))
782+
end
783+
function unassignable_variables(sys, expr)
770784
assignable_syms = vcat(unknowns(sys), parameters(sys))
771-
return filter(x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init=[]))
785+
return filter(
786+
x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init = []))
772787
end
773788

774789
function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...)
@@ -781,7 +796,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
781796
=#
782797
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
783798
seen = Set{Symbol}()
784-
syms_dedup = []; exprs_dedup = []
799+
syms_dedup = []
800+
exprs_dedup = []
785801
for (sym, exp) in Iterators.zip(syms, exprs)
786802
if !in(sym, seen)
787803
push!(syms_dedup, sym)
@@ -795,7 +811,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
795811
end
796812

797813
obs_exprs = observed(affect)
798-
for oexpr in obs_exprs
814+
for oexpr in obs_exprs
799815
invalid_vars = invalid_variables(sys, oexpr)
800816
if length(invalid_vars) > 0
801817
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
@@ -806,7 +822,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
806822
obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size
807823

808824
mod_exprs = modified(affect)
809-
for mexpr in mod_exprs
825+
for mexpr in mod_exprs
810826
if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing
811827
error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
812828
end
@@ -817,37 +833,50 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
817833
end
818834
mod_syms = modified_syms(affect)
819835
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
820-
_, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true)
836+
_, mod_og_val_fun = build_explicit_observed_function(
837+
sys, mod_exprs; return_inplace = true)
821838

822839
overlapping_syms = intersect(mod_syms, obs_syms)
823840
if length(overlapping_syms) > 0
824841
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
825842
end
826843

827844
# sanity checks done! now build the data and update function for observed values
828-
mkzero(sz) = if sz === () 0.0 else zeros(sz) end
829-
_, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true)
830-
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms..., )}(mkzero.(obs_size)))
845+
mkzero(sz) =
846+
if sz === ()
847+
0.0
848+
else
849+
zeros(sz)
850+
end
851+
_, obs_fun = build_explicit_observed_function(
852+
sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []);
853+
return_inplace = true)
854+
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms...,)}(mkzero.(obs_size)))
831855

832856
# okay so now to generate the stuff to assign it back into the system
833857
# note that we reorder the componentarray to make the views coherent wrt the base array
834858
mod_pairs = mod_exprs .=> mod_syms
835859
mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs)
836860
mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs)
837-
_, mod_og_val_fun = build_explicit_observed_function(sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); return_inplace=true)
838-
upd_params_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
839-
upd_unk_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))
840-
841-
upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)}(
842-
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
861+
_, mod_og_val_fun = build_explicit_observed_function(
862+
sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []);
863+
return_inplace = true)
864+
upd_params_fun = setu(
865+
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
866+
upd_unk_fun = setu(
867+
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))
868+
869+
upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs);
870+
last.(mod_unk_pairs)]...,)}(
871+
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
843872
collect(mkzero(size(e)) for e in first.(mod_unk_pairs))]))
844873
upd_params_view = view(upd_component_array, last.(mod_param_pairs))
845874
upd_unks_view = view(upd_component_array, last.(mod_unk_pairs))
846875
let user_affect = func(affect), ctx = context(affect)
847876
function (integ)
848877
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
849878
mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t)
850-
879+
851880
# update the observed values
852881
obs_fun(obs_component_array, integ.u, integ.p..., integ.t)
853882

@@ -860,7 +889,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
860889
user_affect(upd_component_array, obs_component_array)
861890
elseif applicable(user_affect, upd_component_array)
862891
user_affect(upd_component_array)
863-
else
892+
else
864893
@error "User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details"
865894
user_affect(upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message
866895
end

0 commit comments

Comments
 (0)