Skip to content

Commit baf6d07

Browse files
committed
Run formatter
1 parent 0164062 commit baf6d07

File tree

1 file changed

+52
-25
lines changed

1 file changed

+52
-25
lines changed

src/systems/callbacks.jl

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ struct SymbolicContinuousCallback
112112
affect::Union{Vector{Equation}, FunctionalAffect}
113113
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
114114
rootfind::SciMLBase.RootfindOpt
115-
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, rootfind=SciMLBase.LeftRootFind)
115+
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT,
116+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
116117
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
117118
end # Default affect to nothing
118119
end
@@ -121,13 +122,15 @@ make_affect(affect::Tuple) = FunctionalAffect(affect...)
121122
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)
122123

123124
function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
124-
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
125+
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
126+
isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
125127
end
126128
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
127129
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
128130
s = foldr(hash, cb.eqs, init = s)
129131
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
130-
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) : hash(cb.affect_neg, s)
132+
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) :
133+
hash(cb.affect_neg, s)
131134
hash(cb.rootfind, s)
132135
end
133136

@@ -143,8 +146,14 @@ function SymbolicContinuousCallback(args...)
143146
end # wrap eq in vector
144147
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
145148
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
146-
SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback([eqs], affect, affect_neg, rootfind)
147-
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; affect_neg = affect, rootfind=SciMLBase.LeftRootFind) = SymbolicContinuousCallback(eqs, affect, affect_neg, rootfind)
149+
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
150+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
151+
SymbolicContinuousCallback([eqs], affect, affect_neg, rootfind)
152+
end
153+
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
154+
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
155+
SymbolicContinuousCallback(eqs, affect, affect_neg, rootfind)
156+
end
148157

149158
SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
150159
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
@@ -510,13 +519,15 @@ end
510519
Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to
511520
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
512521
"""
513-
function generate_single_rootfinding_callback(eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
514-
ps = full_parameters(sys); kwargs...)
522+
function generate_single_rootfinding_callback(
523+
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
524+
ps = full_parameters(sys); kwargs...)
515525
if !isequal(eq.lhs, 0)
516526
eq = 0 ~ eq.lhs - eq.rhs
517527
end
518-
519-
rf_oop, rf_ip = generate_custom_function(sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...)
528+
529+
rf_oop, rf_ip = generate_custom_function(
530+
sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...)
520531
affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs)
521532
cond = function (u, t, integ)
522533
if DiffEqBase.isinplace(integ.sol.prob)
@@ -527,11 +538,13 @@ function generate_single_rootfinding_callback(eq, cb, sys::AbstractODESystem, dv
527538
rf_oop(u, parameter_values(integ), t)
528539
end
529540
end
530-
return ContinuousCallback(cond, affect_function.affect, affect_function.affect_neg, rootfind=cb.rootfind)
541+
return ContinuousCallback(
542+
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind)
531543
end
532544

533-
function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
534-
ps = full_parameters(sys); rootfind=SciMLBase.RightRootFind, kwargs...)
545+
function generate_vector_rootfinding_callback(
546+
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
547+
ps = full_parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
535548
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
536549
num_eqs = length.(eqs)
537550
# fuse equations to create VectorContinuousCallback
@@ -543,9 +556,16 @@ function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs =
543556
end
544557

545558
rhss = map(x -> x.rhs, eqs)
546-
_, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
547-
548-
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs]
559+
_, rf_ip = generate_custom_function(
560+
sys, rhss, dvs, ps; expression = Val{false}, kwargs...)
561+
562+
affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(
563+
cb,
564+
sys,
565+
dvs,
566+
ps,
567+
kwargs)
568+
for cb in cbs]
549569
cond = function (out, u, t, integ)
550570
rf_ip(out, u, parameter_values(integ), t)
551571
end
@@ -571,7 +591,8 @@ function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs =
571591
affect_neg(integ)
572592
end
573593
end
574-
return VectorContinuousCallback(cond, affect, affect_neg, length(eqs), rootfind=rootfind)
594+
return VectorContinuousCallback(
595+
cond, affect, affect_neg, length(eqs), rootfind = rootfind)
575596
end
576597

577598
"""
@@ -582,13 +603,14 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
582603
eq_neg_aff = affect_negs(cb)
583604
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
584605
if eq_neg_aff === eq_aff
585-
affect_neg = affect
606+
affect_neg = affect
586607
elseif isnothing(eq_neg_aff)
587608
affect_neg = nothing
588609
else
589-
affect_neg = compile_affect(eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
610+
affect_neg = compile_affect(
611+
eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
590612
end
591-
(affect=affect, affect_neg=affect_neg)
613+
(affect = affect, affect_neg = affect_neg)
592614
end
593615

594616
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
@@ -609,19 +631,24 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
609631

610632
# group the cbs by what rootfind op they use
611633
# groupby would be very useful here, but alas
612-
cb_classes = Dict{@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
613-
for cb in cbs
614-
push!(get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind=cb.rootfind, )), cb)
634+
cb_classes = Dict{
635+
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
636+
for cb in cbs
637+
push!(
638+
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
639+
cb)
615640
end
616641

617642
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
618-
compiled_callbacks = map(collect(pairs(sort!(OrderedDict(cb_classes); by=p->p.rootfind)))) do (equiv_class, cbs_in_class)
619-
return generate_vector_rootfinding_callback(cbs_in_class, sys, dvs, ps; rootfind=equiv_class.rootfind, kwargs...)
643+
compiled_callbacks = map(collect(pairs(sort!(
644+
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
645+
return generate_vector_rootfinding_callback(
646+
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
620647
end
621648
if length(compiled_callbacks) == 1
622649
return compiled_callbacks[]
623650
else
624-
return CallbackSet(compiled_callbacks...)
651+
return CallbackSet(compiled_callbacks...)
625652
end
626653
end
627654

0 commit comments

Comments
 (0)