Skip to content

Commit aecd59b

Browse files
committed
Formatter
1 parent 0654909 commit aecd59b

File tree

2 files changed

+107
-77
lines changed

2 files changed

+107
-77
lines changed

src/systems/callbacks.jl

Lines changed: 97 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -116,29 +116,33 @@ function ImperativeAffect(f::Function;
116116
modified::NamedTuple = NamedTuple{()}(()),
117117
ctx = nothing,
118118
skip_checks = false)
119-
ImperativeAffect(f,
119+
ImperativeAffect(f,
120120
collect(values(observed)), collect(keys(observed)),
121-
collect(values(modified)), collect(keys(modified)),
121+
collect(values(modified)), collect(keys(modified)),
122122
ctx, skip_checks)
123123
end
124124
function ImperativeAffect(f::Function, modified::NamedTuple;
125-
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks=false)
126-
ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
125+
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
126+
ImperativeAffect(
127+
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
127128
end
128129
function ImperativeAffect(
129-
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks=false)
130-
ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
130+
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false)
131+
ImperativeAffect(
132+
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
131133
end
132134
function ImperativeAffect(
133-
f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks=false)
134-
ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
135+
f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false)
136+
ImperativeAffect(
137+
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
135138
end
136139

137-
function Base.show(io::IO, mfa::ImperativeAffect)
138-
obs_vals = join(map((ob,nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
139-
mod_vals = join(map((md,nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
140+
function Base.show(io::IO, mfa::ImperativeAffect)
141+
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
142+
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
140143
affect = mfa.f
141-
print(io, "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
144+
print(io,
145+
"ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
142146
end
143147
func(f::ImperativeAffect) = f.f
144148
context(a::ImperativeAffect) = a.ctx
@@ -234,24 +238,25 @@ struct SymbolicContinuousCallback
234238
affect_neg::Union{Vector{Equation}, FunctionalAffect, ImperativeAffect, Nothing}
235239
rootfind::SciMLBase.RootfindOpt
236240
reinitializealg::SciMLBase.DAEInitializationAlgorithm
237-
function SymbolicContinuousCallback(;
238-
eqs::Vector{Equation},
239-
affect = NULL_AFFECT,
240-
affect_neg = affect,
241-
rootfind = SciMLBase.LeftRootFind,
242-
initialize=NULL_AFFECT,
243-
finalize=NULL_AFFECT,
244-
reinitializealg=SciMLBase.CheckInit())
245-
new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg)
241+
function SymbolicContinuousCallback(;
242+
eqs::Vector{Equation},
243+
affect = NULL_AFFECT,
244+
affect_neg = affect,
245+
rootfind = SciMLBase.LeftRootFind,
246+
initialize = NULL_AFFECT,
247+
finalize = NULL_AFFECT,
248+
reinitializealg = SciMLBase.CheckInit())
249+
new(eqs, initialize, finalize, make_affect(affect),
250+
make_affect(affect_neg), rootfind, reinitializealg)
246251
end # Default affect to nothing
247252
end
248253
make_affect(affect) = affect
249254
make_affect(affect::Tuple) = FunctionalAffect(affect...)
250255
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
251256

252257
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
253-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
254-
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) &&
258+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
259+
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) &&
255260
isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
256261
end
257262
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
@@ -266,10 +271,9 @@ function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
266271
hash(cb.rootfind, s)
267272
end
268273

269-
270274
function Base.show(io::IO, cb::SymbolicContinuousCallback)
271275
indent = get(io, :indent, 0)
272-
iio = IOContext(io, :indent => indent+1)
276+
iio = IOContext(io, :indent => indent + 1)
273277
print(io, "SymbolicContinuousCallback(")
274278
print(iio, "Equations:")
275279
show(iio, equations(cb))
@@ -298,7 +302,7 @@ end
298302

299303
function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback)
300304
indent = get(io, :indent, 0)
301-
iio = IOContext(io, :indent => indent+1)
305+
iio = IOContext(io, :indent => indent + 1)
302306
println(io, "SymbolicContinuousCallback:")
303307
println(iio, "Equations:")
304308
show(iio, mime, equations(cb))
@@ -338,14 +342,18 @@ end # wrap eq in vector
338342
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
339343
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
340344
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
341-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
345+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind,
346+
initialize = NULL_AFFECT, finalize = NULL_AFFECT)
342347
SymbolicContinuousCallback(
343-
eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize)
348+
eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind,
349+
initialize = initialize, finalize = finalize)
344350
end
345351
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
346-
affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
352+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind,
353+
initialize = NULL_AFFECT, finalize = NULL_AFFECT)
347354
SymbolicContinuousCallback(
348-
eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize)
355+
eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind,
356+
initialize = initialize, finalize = finalize)
349357
end
350358

351359
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
@@ -385,8 +393,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
385393
end
386394

387395
reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg
388-
reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) =
389-
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
396+
function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback})
397+
mapreduce(
398+
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
399+
end
390400

391401
namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
392402
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
@@ -436,7 +446,8 @@ struct SymbolicDiscreteCallback
436446
affects::Any
437447
reinitializealg::SciMLBase.DAEInitializationAlgorithm
438448

439-
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT, reinitializealg=SciMLBase.CheckInit())
449+
function SymbolicDiscreteCallback(
450+
condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit())
440451
c = scalarize_condition(condition)
441452
a = scalarize_affects(affects)
442453
new(c, a, reinitializealg)
@@ -498,8 +509,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
498509
end
499510

500511
reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg
501-
reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) =
502-
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
512+
function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback})
513+
mapreduce(
514+
reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])
515+
end
503516

504517
function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
505518
af = affects(cb)
@@ -781,7 +794,8 @@ function generate_single_rootfinding_callback(
781794
end
782795
end
783796

784-
user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i)
797+
user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT :
798+
(c, u, t, i) -> affect_function.initialize(i)
785799
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
786800
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
787801
initfn = let save_idxs = save_idxs
@@ -795,17 +809,19 @@ function generate_single_rootfinding_callback(
795809
else
796810
initfn = user_initfun
797811
end
798-
812+
799813
return ContinuousCallback(
800-
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
801-
initialize = initfn,
802-
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i),
814+
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
815+
initialize = initfn,
816+
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT :
817+
(c, u, t, i) -> affect_function.finalize(i),
803818
initializealg = reinitialization_alg(cb))
804819
end
805820

806821
function generate_vector_rootfinding_callback(
807822
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
808-
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, reinitialization = SciMLBase.CheckInit(), kwargs...)
823+
ps = parameters(sys); rootfind = SciMLBase.RightRootFind,
824+
reinitialization = SciMLBase.CheckInit(), kwargs...)
809825
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
810826
num_eqs = length.(eqs)
811827
# fuse equations to create VectorContinuousCallback
@@ -821,11 +837,12 @@ function generate_vector_rootfinding_callback(
821837
sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
822838

823839
affect_functions = @NamedTuple{
824-
affect::Function,
825-
affect_neg::Union{Function, Nothing},
826-
initialize::Union{Function, Nothing},
840+
affect::Function,
841+
affect_neg::Union{Function, Nothing},
842+
initialize::Union{Function, Nothing},
827843
finalize::Union{Function, Nothing}}[
828-
compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs]
844+
compile_affect_fn(cb, sys, dvs, ps, kwargs)
845+
for cb in cbs]
829846
cond = function (out, u, t, integ)
830847
rf_ip(out, u, parameter_values(integ), t)
831848
end
@@ -861,17 +878,20 @@ function generate_vector_rootfinding_callback(
861878
if isnothing(func)
862879
continue
863880
else
864-
func(integ)
881+
func(integ)
865882
end
866883
end
867884
end
868885
end
869886
end
870887
end
871-
initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
872-
finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT)
888+
initialize = handle_optional_setup_fn(
889+
map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
890+
finalize = handle_optional_setup_fn(
891+
map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT)
873892
return VectorContinuousCallback(
874-
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization)
893+
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize,
894+
finalize = finalize, initializealg = reinitialization)
875895
end
876896

877897
"""
@@ -881,8 +901,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
881901
eq_aff = affects(cb)
882902
eq_neg_aff = affect_negs(cb)
883903
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
884-
function compile_optional_affect(aff, default=nothing)
885-
if isnothing(aff) || aff==default
904+
function compile_optional_affect(aff, default = nothing)
905+
if isnothing(aff) || aff == default
886906
return nothing
887907
else
888908
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
@@ -918,21 +938,23 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
918938
# groupby would be very useful here, but alas
919939
cb_classes = Dict{
920940
@NamedTuple{
921-
rootfind::SciMLBase.RootfindOpt,
941+
rootfind::SciMLBase.RootfindOpt,
922942
reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}()
923943
for cb in cbs
924944
push!(
925-
get!(() -> SymbolicContinuousCallback[], cb_classes, (
926-
rootfind = cb.rootfind,
927-
reinitialization = reinitialization_alg(cb))),
945+
get!(() -> SymbolicContinuousCallback[], cb_classes,
946+
(
947+
rootfind = cb.rootfind,
948+
reinitialization = reinitialization_alg(cb))),
928949
cb)
929950
end
930951

931952
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
932953
compiled_callbacks = map(collect(pairs(sort!(
933954
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
934955
return generate_vector_rootfinding_callback(
935-
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, reinitialization=equiv_class.reinitialization, kwargs...)
956+
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind,
957+
reinitialization = equiv_class.reinitialization, kwargs...)
936958
end
937959
if length(compiled_callbacks) == 1
938960
return compiled_callbacks[]
@@ -984,29 +1006,34 @@ function invalid_variables(sys, expr)
9841006
filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = []))
9851007
end
9861008
function unassignable_variables(sys, expr)
987-
assignable_syms = reduce(vcat, Symbolics.scalarize.(vcat(unknowns(sys), parameters(sys))); init=[])
1009+
assignable_syms = reduce(
1010+
vcat, Symbolics.scalarize.(vcat(unknowns(sys), parameters(sys))); init = [])
9881011
written = reduce(vcat, Symbolics.scalarize.(vars(expr)); init = [])
9891012
return filter(
9901013
x -> !any(isequal(x), assignable_syms), written)
9911014
end
9921015

993-
@generated function _generated_writeback(integ, setters::NamedTuple{NS1,<:Tuple}, values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
1016+
@generated function _generated_writeback(integ, setters::NamedTuple{NS1, <:Tuple},
1017+
values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
9941018
setter_exprs = []
995-
for name in NS2
1019+
for name in NS2
9961020
if !(name in NS1)
9971021
missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to."
9981022
error(missing_name)
9991023
end
10001024
push!(setter_exprs, :(setters.$name(integ, values.$name)))
10011025
end
1002-
return :(begin $(setter_exprs...) end)
1026+
return :(begin
1027+
$(setter_exprs...)
1028+
end)
10031029
end
10041030

10051031
function check_assignable(sys, sym)
10061032
if symbolic_type(sym) == ScalarSymbolic()
10071033
is_variable(sys, sym) || is_parameter(sys, sym)
10081034
elseif symbolic_type(sym) == ArraySymbolic()
1009-
is_variable(sys, sym) || is_parameter(sys, sym) || all(x -> check_assignable(sys, x), collect(sym))
1035+
is_variable(sys, sym) || is_parameter(sys, sym) ||
1036+
all(x -> check_assignable(sys, x), collect(sym))
10101037
elseif sym isa Union{AbstractArray, Tuple}
10111038
all(x -> check_assignable(sys, x), sym)
10121039
else
@@ -1084,13 +1111,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
10841111

10851112
# okay so now to generate the stuff to assign it back into the system
10861113
mod_pairs = mod_exprs .=> mod_syms
1087-
mod_names = (mod_syms..., )
1114+
mod_names = (mod_syms...,)
10881115
mod_og_val_fun = build_explicit_observed_function(
10891116
sys, Symbolics.scalarize.(first.(mod_pairs));
10901117
array_type = :tuple)
10911118

10921119
upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))
1093-
1120+
10941121
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
10951122
save_idxs = get(ic.callback_to_clocks, cb, Int[])
10961123
else
@@ -1104,10 +1131,12 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
11041131
upd_component_array = NamedTuple{mod_names}(modvals)
11051132

11061133
# update the observed values
1107-
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p, integ.t))
1134+
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
1135+
integ.u, integ.p, integ.t))
11081136

11091137
# let the user do their thing
1110-
modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ)
1138+
modvals = if applicable(
1139+
user_affect, upd_component_array, obs_component_array, ctx, integ)
11111140
user_affect(upd_component_array, obs_component_array, ctx, integ)
11121141
elseif applicable(user_affect, upd_component_array, obs_component_array, ctx)
11131142
user_affect(upd_component_array, obs_component_array, ctx)
@@ -1122,15 +1151,16 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
11221151

11231152
# write the new values back to the integrator
11241153
_generated_writeback(integ, upd_funs, modvals)
1125-
1154+
11261155
for idx in save_idxs
11271156
SciMLBase.save_discretes!(integ, idx)
11281157
end
11291158
end
11301159
end
11311160
end
11321161

1133-
function compile_affect(affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...)
1162+
function compile_affect(
1163+
affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...)
11341164
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
11351165
end
11361166

0 commit comments

Comments
 (0)