@@ -112,7 +112,8 @@ struct SymbolicContinuousCallback
112
112
affect:: Union{Vector{Equation}, FunctionalAffect}
113
113
affect_neg:: Union{Vector{Equation}, FunctionalAffect, Nothing}
114
114
rootfind:: SciMLBase.RootfindOpt
115
- function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT, affect_neg = affect, rootfind= SciMLBase. LeftRootFind)
115
+ function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT,
116
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
116
117
new (eqs, make_affect (affect), make_affect (affect_neg), rootfind)
117
118
end # Default affect to nothing
118
119
end
@@ -121,13 +122,15 @@ make_affect(affect::Tuple) = FunctionalAffect(affect...)
121
122
make_affect (affect:: NamedTuple ) = FunctionalAffect (; affect... )
122
123
123
124
function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
124
- isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) && isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
125
+ isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
126
+ isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
125
127
end
126
128
Base. isempty (cb:: SymbolicContinuousCallback ) = isempty (cb. eqs)
127
129
function Base. hash (cb:: SymbolicContinuousCallback , s:: UInt )
128
130
s = foldr (hash, cb. eqs, init = s)
129
131
s = cb. affect isa AbstractVector ? foldr (hash, cb. affect, init = s) : hash (cb. affect, s)
130
- s = cb. affect_neg isa AbstractVector ? foldr (hash, cb. affect_neg, init = s) : hash (cb. affect_neg, s)
132
+ s = cb. affect_neg isa AbstractVector ? foldr (hash, cb. affect_neg, init = s) :
133
+ hash (cb. affect_neg, s)
131
134
hash (cb. rootfind, s)
132
135
end
133
136
@@ -143,8 +146,14 @@ function SymbolicContinuousCallback(args...)
143
146
end # wrap eq in vector
144
147
SymbolicContinuousCallback (p:: Pair ) = SymbolicContinuousCallback (p[1 ], p[2 ])
145
148
SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
146
- SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT; affect_neg = affect, rootfind= SciMLBase. LeftRootFind) = SymbolicContinuousCallback ([eqs], affect, affect_neg, rootfind)
147
- SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT; affect_neg = affect, rootfind= SciMLBase. LeftRootFind) = SymbolicContinuousCallback (eqs, affect, affect_neg, rootfind)
149
+ function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
150
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
151
+ SymbolicContinuousCallback ([eqs], affect, affect_neg, rootfind)
152
+ end
153
+ function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
154
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
155
+ SymbolicContinuousCallback (eqs, affect, affect_neg, rootfind)
156
+ end
148
157
149
158
SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
150
159
SymbolicContinuousCallbacks (cbs:: Vector{<:SymbolicContinuousCallback} ) = cbs
@@ -510,13 +519,15 @@ end
510
519
Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to
511
520
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
512
521
"""
513
- function generate_single_rootfinding_callback (eq, cb, sys:: AbstractODESystem , dvs = unknowns (sys),
514
- ps = full_parameters (sys); kwargs... )
522
+ function generate_single_rootfinding_callback (
523
+ eq, cb, sys:: AbstractODESystem , dvs = unknowns (sys),
524
+ ps = full_parameters (sys); kwargs... )
515
525
if ! isequal (eq. lhs, 0 )
516
526
eq = 0 ~ eq. lhs - eq. rhs
517
527
end
518
-
519
- rf_oop, rf_ip = generate_custom_function (sys, [eq. rhs], dvs, ps; expression = Val{false }, kwargs... )
528
+
529
+ rf_oop, rf_ip = generate_custom_function (
530
+ sys, [eq. rhs], dvs, ps; expression = Val{false }, kwargs... )
520
531
affect_function = compile_affect_fn (cb, sys, dvs, ps, kwargs)
521
532
cond = function (u, t, integ)
522
533
if DiffEqBase. isinplace (integ. sol. prob)
@@ -527,11 +538,13 @@ function generate_single_rootfinding_callback(eq, cb, sys::AbstractODESystem, dv
527
538
rf_oop (u, parameter_values (integ), t)
528
539
end
529
540
end
530
- return ContinuousCallback (cond, affect_function. affect, affect_function. affect_neg, rootfind= cb. rootfind)
541
+ return ContinuousCallback (
542
+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind)
531
543
end
532
544
533
- function generate_vector_rootfinding_callback (cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
534
- ps = full_parameters (sys); rootfind= SciMLBase. RightRootFind, kwargs... )
545
+ function generate_vector_rootfinding_callback (
546
+ cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
547
+ ps = full_parameters (sys); rootfind = SciMLBase. RightRootFind, kwargs... )
535
548
eqs = map (cb -> flatten_equations (cb. eqs), cbs)
536
549
num_eqs = length .(eqs)
537
550
# fuse equations to create VectorContinuousCallback
@@ -543,9 +556,16 @@ function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs =
543
556
end
544
557
545
558
rhss = map (x -> x. rhs, eqs)
546
- _, rf_ip = generate_custom_function (sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
547
-
548
- affect_functions = @NamedTuple {affect:: Function , affect_neg:: Union{Function, Nothing} }[compile_affect_fn (cb, sys, dvs, ps, kwargs) for cb in cbs]
559
+ _, rf_ip = generate_custom_function (
560
+ sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
561
+
562
+ affect_functions = @NamedTuple {affect:: Function , affect_neg:: Union{Function, Nothing} }[compile_affect_fn (
563
+ cb,
564
+ sys,
565
+ dvs,
566
+ ps,
567
+ kwargs)
568
+ for cb in cbs]
549
569
cond = function (out, u, t, integ)
550
570
rf_ip (out, u, parameter_values (integ), t)
551
571
end
@@ -571,7 +591,8 @@ function generate_vector_rootfinding_callback(cbs, sys::AbstractODESystem, dvs =
571
591
affect_neg (integ)
572
592
end
573
593
end
574
- return VectorContinuousCallback (cond, affect, affect_neg, length (eqs), rootfind= rootfind)
594
+ return VectorContinuousCallback (
595
+ cond, affect, affect_neg, length (eqs), rootfind = rootfind)
575
596
end
576
597
577
598
"""
@@ -582,13 +603,14 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
582
603
eq_neg_aff = affect_negs (cb)
583
604
affect = compile_affect (eq_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
584
605
if eq_neg_aff === eq_aff
585
- affect_neg = affect
606
+ affect_neg = affect
586
607
elseif isnothing (eq_neg_aff)
587
608
affect_neg = nothing
588
609
else
589
- affect_neg = compile_affect (eq_neg_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
610
+ affect_neg = compile_affect (
611
+ eq_neg_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
590
612
end
591
- (affect= affect, affect_neg= affect_neg)
613
+ (affect = affect, affect_neg = affect_neg)
592
614
end
593
615
594
616
function generate_rootfinding_callback (cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
@@ -609,19 +631,24 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
609
631
610
632
# group the cbs by what rootfind op they use
611
633
# groupby would be very useful here, but alas
612
- cb_classes = Dict {@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}} ()
613
- for cb in cbs
614
- push! (get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind= cb. rootfind, )), cb)
634
+ cb_classes = Dict{
635
+ @NamedTuple {rootfind:: SciMLBase.RootfindOpt }, Vector{SymbolicContinuousCallback}}()
636
+ for cb in cbs
637
+ push! (
638
+ get! (() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb. rootfind,)),
639
+ cb)
615
640
end
616
641
617
642
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
618
- compiled_callbacks = map (collect (pairs (sort! (OrderedDict (cb_classes); by= p-> p. rootfind)))) do (equiv_class, cbs_in_class)
619
- return generate_vector_rootfinding_callback (cbs_in_class, sys, dvs, ps; rootfind= equiv_class. rootfind, kwargs... )
643
+ compiled_callbacks = map (collect (pairs (sort! (
644
+ OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
645
+ return generate_vector_rootfinding_callback (
646
+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, kwargs... )
620
647
end
621
648
if length (compiled_callbacks) == 1
622
649
return compiled_callbacks[]
623
650
else
624
- return CallbackSet (compiled_callbacks... )
651
+ return CallbackSet (compiled_callbacks... )
625
652
end
626
653
end
627
654
0 commit comments