@@ -116,29 +116,33 @@ function ImperativeAffect(f::Function;
116
116
modified:: NamedTuple = NamedTuple {()} (()),
117
117
ctx = nothing ,
118
118
skip_checks = false )
119
- ImperativeAffect (f,
119
+ ImperativeAffect (f,
120
120
collect (values (observed)), collect (keys (observed)),
121
- collect (values (modified)), collect (keys (modified)),
121
+ collect (values (modified)), collect (keys (modified)),
122
122
ctx, skip_checks)
123
123
end
124
124
function ImperativeAffect (f:: Function , modified:: NamedTuple ;
125
- observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks= false )
126
- ImperativeAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
125
+ observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks = false )
126
+ ImperativeAffect (
127
+ f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
127
128
end
128
129
function ImperativeAffect (
129
- f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing , skip_checks= false )
130
- ImperativeAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
130
+ f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing , skip_checks = false )
131
+ ImperativeAffect (
132
+ f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
131
133
end
132
134
function ImperativeAffect (
133
- f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx; skip_checks= false )
134
- ImperativeAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
135
+ f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx; skip_checks = false )
136
+ ImperativeAffect (
137
+ f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
135
138
end
136
139
137
- function Base. show (io:: IO , mfa:: ImperativeAffect )
138
- obs_vals = join (map ((ob,nm) -> " $ob => $nm " , mfa. obs, mfa. obs_syms), " , " )
139
- mod_vals = join (map ((md,nm) -> " $md => $nm " , mfa. modified, mfa. mod_syms), " , " )
140
+ function Base. show (io:: IO , mfa:: ImperativeAffect )
141
+ obs_vals = join (map ((ob, nm) -> " $ob => $nm " , mfa. obs, mfa. obs_syms), " , " )
142
+ mod_vals = join (map ((md, nm) -> " $md => $nm " , mfa. modified, mfa. mod_syms), " , " )
140
143
affect = mfa. f
141
- print (io, " ImperativeAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
144
+ print (io,
145
+ " ImperativeAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
142
146
end
143
147
func (f:: ImperativeAffect ) = f. f
144
148
context (a:: ImperativeAffect ) = a. ctx
@@ -234,24 +238,25 @@ struct SymbolicContinuousCallback
234
238
affect_neg:: Union{Vector{Equation}, FunctionalAffect, ImperativeAffect, Nothing}
235
239
rootfind:: SciMLBase.RootfindOpt
236
240
reinitializealg:: SciMLBase.DAEInitializationAlgorithm
237
- function SymbolicContinuousCallback (;
238
- eqs:: Vector{Equation} ,
239
- affect = NULL_AFFECT,
240
- affect_neg = affect,
241
- rootfind = SciMLBase. LeftRootFind,
242
- initialize= NULL_AFFECT,
243
- finalize= NULL_AFFECT,
244
- reinitializealg= SciMLBase. CheckInit ())
245
- new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind, reinitializealg)
241
+ function SymbolicContinuousCallback (;
242
+ eqs:: Vector{Equation} ,
243
+ affect = NULL_AFFECT,
244
+ affect_neg = affect,
245
+ rootfind = SciMLBase. LeftRootFind,
246
+ initialize = NULL_AFFECT,
247
+ finalize = NULL_AFFECT,
248
+ reinitializealg = SciMLBase. CheckInit ())
249
+ new (eqs, initialize, finalize, make_affect (affect),
250
+ make_affect (affect_neg), rootfind, reinitializealg)
246
251
end # Default affect to nothing
247
252
end
248
253
make_affect (affect) = affect
249
254
make_affect (affect:: Tuple ) = FunctionalAffect (affect... )
250
255
make_affect (affect:: NamedTuple ) = FunctionalAffect (; affect... )
251
256
252
257
function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
253
- isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
254
- isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
258
+ isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
259
+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
255
260
isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
256
261
end
257
262
Base. isempty (cb:: SymbolicContinuousCallback ) = isempty (cb. eqs)
@@ -266,10 +271,9 @@ function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
266
271
hash (cb. rootfind, s)
267
272
end
268
273
269
-
270
274
function Base. show (io:: IO , cb:: SymbolicContinuousCallback )
271
275
indent = get (io, :indent , 0 )
272
- iio = IOContext (io, :indent => indent+ 1 )
276
+ iio = IOContext (io, :indent => indent + 1 )
273
277
print (io, " SymbolicContinuousCallback(" )
274
278
print (iio, " Equations:" )
275
279
show (iio, equations (cb))
298
302
299
303
function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: SymbolicContinuousCallback )
300
304
indent = get (io, :indent , 0 )
301
- iio = IOContext (io, :indent => indent+ 1 )
305
+ iio = IOContext (io, :indent => indent + 1 )
302
306
println (io, " SymbolicContinuousCallback:" )
303
307
println (iio, " Equations:" )
304
308
show (iio, mime, equations (cb))
@@ -338,14 +342,18 @@ end # wrap eq in vector
338
342
SymbolicContinuousCallback (p:: Pair ) = SymbolicContinuousCallback (p[1 ], p[2 ])
339
343
SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
340
344
function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
341
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
345
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind,
346
+ initialize = NULL_AFFECT, finalize = NULL_AFFECT)
342
347
SymbolicContinuousCallback (
343
- eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize= initialize, finalize= finalize)
348
+ eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind,
349
+ initialize = initialize, finalize = finalize)
344
350
end
345
351
function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
346
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT)
352
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind,
353
+ initialize = NULL_AFFECT, finalize = NULL_AFFECT)
347
354
SymbolicContinuousCallback (
348
- eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize= initialize, finalize= finalize)
355
+ eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind,
356
+ initialize = initialize, finalize = finalize)
349
357
end
350
358
351
359
SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
@@ -385,8 +393,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
385
393
end
386
394
387
395
reinitialization_alg (cb:: SymbolicContinuousCallback ) = cb. reinitializealg
388
- reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} ) =
389
- mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
396
+ function reinitialization_algs (cbs:: Vector{SymbolicContinuousCallback} )
397
+ mapreduce (
398
+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
399
+ end
390
400
391
401
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
392
402
namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
@@ -436,7 +446,8 @@ struct SymbolicDiscreteCallback
436
446
affects:: Any
437
447
reinitializealg:: SciMLBase.DAEInitializationAlgorithm
438
448
439
- function SymbolicDiscreteCallback (condition, affects = NULL_AFFECT, reinitializealg= SciMLBase. CheckInit ())
449
+ function SymbolicDiscreteCallback (
450
+ condition, affects = NULL_AFFECT, reinitializealg = SciMLBase. CheckInit ())
440
451
c = scalarize_condition (condition)
441
452
a = scalarize_affects (affects)
442
453
new (c, a, reinitializealg)
@@ -498,8 +509,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
498
509
end
499
510
500
511
reinitialization_alg (cb:: SymbolicDiscreteCallback ) = cb. reinitializealg
501
- reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} ) =
502
- mapreduce (reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
512
+ function reinitialization_algs (cbs:: Vector{SymbolicDiscreteCallback} )
513
+ mapreduce (
514
+ reinitialization_alg, vcat, cbs, init = SciMLBase. DAEInitializationAlgorithm[])
515
+ end
503
516
504
517
function namespace_callback (cb:: SymbolicDiscreteCallback , s):: SymbolicDiscreteCallback
505
518
af = affects (cb)
@@ -781,7 +794,8 @@ function generate_single_rootfinding_callback(
781
794
end
782
795
end
783
796
784
- user_initfun = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function. initialize (i)
797
+ user_initfun = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT :
798
+ (c, u, t, i) -> affect_function. initialize (i)
785
799
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing &&
786
800
(save_idxs = get (ic. callback_to_clocks, cb, nothing )) != = nothing
787
801
initfn = let save_idxs = save_idxs
@@ -795,17 +809,19 @@ function generate_single_rootfinding_callback(
795
809
else
796
810
initfn = user_initfun
797
811
end
798
-
812
+
799
813
return ContinuousCallback (
800
- cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
801
- initialize = initfn,
802
- finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i),
814
+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
815
+ initialize = initfn,
816
+ finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT :
817
+ (c, u, t, i) -> affect_function. finalize (i),
803
818
initializealg = reinitialization_alg (cb))
804
819
end
805
820
806
821
function generate_vector_rootfinding_callback (
807
822
cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
808
- ps = parameters (sys); rootfind = SciMLBase. RightRootFind, reinitialization = SciMLBase. CheckInit (), kwargs... )
823
+ ps = parameters (sys); rootfind = SciMLBase. RightRootFind,
824
+ reinitialization = SciMLBase. CheckInit (), kwargs... )
809
825
eqs = map (cb -> flatten_equations (cb. eqs), cbs)
810
826
num_eqs = length .(eqs)
811
827
# fuse equations to create VectorContinuousCallback
@@ -821,11 +837,12 @@ function generate_vector_rootfinding_callback(
821
837
sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
822
838
823
839
affect_functions = @NamedTuple {
824
- affect:: Function ,
825
- affect_neg:: Union{Function, Nothing} ,
826
- initialize:: Union{Function, Nothing} ,
840
+ affect:: Function ,
841
+ affect_neg:: Union{Function, Nothing} ,
842
+ initialize:: Union{Function, Nothing} ,
827
843
finalize:: Union{Function, Nothing} }[
828
- compile_affect_fn (cb, sys, dvs, ps, kwargs) for cb in cbs]
844
+ compile_affect_fn (cb, sys, dvs, ps, kwargs)
845
+ for cb in cbs]
829
846
cond = function (out, u, t, integ)
830
847
rf_ip (out, u, parameter_values (integ), t)
831
848
end
@@ -861,17 +878,20 @@ function generate_vector_rootfinding_callback(
861
878
if isnothing (func)
862
879
continue
863
880
else
864
- func (integ)
881
+ func (integ)
865
882
end
866
883
end
867
884
end
868
885
end
869
886
end
870
887
end
871
- initialize = handle_optional_setup_fn (map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
872
- finalize = handle_optional_setup_fn (map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
888
+ initialize = handle_optional_setup_fn (
889
+ map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
890
+ finalize = handle_optional_setup_fn (
891
+ map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
873
892
return VectorContinuousCallback (
874
- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization)
893
+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize,
894
+ finalize = finalize, initializealg = reinitialization)
875
895
end
876
896
877
897
"""
@@ -881,8 +901,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
881
901
eq_aff = affects (cb)
882
902
eq_neg_aff = affect_negs (cb)
883
903
affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
884
- function compile_optional_affect (aff, default= nothing )
885
- if isnothing (aff) || aff== default
904
+ function compile_optional_affect (aff, default = nothing )
905
+ if isnothing (aff) || aff == default
886
906
return nothing
887
907
else
888
908
return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
@@ -918,21 +938,23 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
918
938
# groupby would be very useful here, but alas
919
939
cb_classes = Dict{
920
940
@NamedTuple {
921
- rootfind:: SciMLBase.RootfindOpt ,
941
+ rootfind:: SciMLBase.RootfindOpt ,
922
942
reinitialization:: SciMLBase.DAEInitializationAlgorithm }, Vector{SymbolicContinuousCallback}}()
923
943
for cb in cbs
924
944
push! (
925
- get! (() -> SymbolicContinuousCallback[], cb_classes, (
926
- rootfind = cb. rootfind,
927
- reinitialization = reinitialization_alg (cb))),
945
+ get! (() -> SymbolicContinuousCallback[], cb_classes,
946
+ (
947
+ rootfind = cb. rootfind,
948
+ reinitialization = reinitialization_alg (cb))),
928
949
cb)
929
950
end
930
951
931
952
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
932
953
compiled_callbacks = map (collect (pairs (sort! (
933
954
OrderedDict (cb_classes); by = p -> p. rootfind)))) do (equiv_class, cbs_in_class)
934
955
return generate_vector_rootfinding_callback (
935
- cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind, reinitialization= equiv_class. reinitialization, kwargs... )
956
+ cbs_in_class, sys, dvs, ps; rootfind = equiv_class. rootfind,
957
+ reinitialization = equiv_class. reinitialization, kwargs... )
936
958
end
937
959
if length (compiled_callbacks) == 1
938
960
return compiled_callbacks[]
@@ -984,29 +1006,34 @@ function invalid_variables(sys, expr)
984
1006
filter (x -> ! any (isequal (x), all_symbols (sys)), reduce (vcat, vars (expr); init = []))
985
1007
end
986
1008
function unassignable_variables (sys, expr)
987
- assignable_syms = reduce (vcat, Symbolics. scalarize .(vcat (unknowns (sys), parameters (sys))); init= [])
1009
+ assignable_syms = reduce (
1010
+ vcat, Symbolics. scalarize .(vcat (unknowns (sys), parameters (sys))); init = [])
988
1011
written = reduce (vcat, Symbolics. scalarize .(vars (expr)); init = [])
989
1012
return filter (
990
1013
x -> ! any (isequal (x), assignable_syms), written)
991
1014
end
992
1015
993
- @generated function _generated_writeback (integ, setters:: NamedTuple{NS1,<:Tuple} , values:: NamedTuple{NS2, <:Tuple} ) where {NS1, NS2}
1016
+ @generated function _generated_writeback (integ, setters:: NamedTuple{NS1, <:Tuple} ,
1017
+ values:: NamedTuple{NS2, <:Tuple} ) where {NS1, NS2}
994
1018
setter_exprs = []
995
- for name in NS2
1019
+ for name in NS2
996
1020
if ! (name in NS1)
997
1021
missing_name = " Tried to write back to $name from affect; only declared states ($NS1 ) may be written to."
998
1022
error (missing_name)
999
1023
end
1000
1024
push! (setter_exprs, :(setters.$ name (integ, values.$ name)))
1001
1025
end
1002
- return :(begin $ (setter_exprs... ) end )
1026
+ return :(begin
1027
+ $ (setter_exprs... )
1028
+ end )
1003
1029
end
1004
1030
1005
1031
function check_assignable (sys, sym)
1006
1032
if symbolic_type (sym) == ScalarSymbolic ()
1007
1033
is_variable (sys, sym) || is_parameter (sys, sym)
1008
1034
elseif symbolic_type (sym) == ArraySymbolic ()
1009
- is_variable (sys, sym) || is_parameter (sys, sym) || all (x -> check_assignable (sys, x), collect (sym))
1035
+ is_variable (sys, sym) || is_parameter (sys, sym) ||
1036
+ all (x -> check_assignable (sys, x), collect (sym))
1010
1037
elseif sym isa Union{AbstractArray, Tuple}
1011
1038
all (x -> check_assignable (sys, x), sym)
1012
1039
else
@@ -1084,13 +1111,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
1084
1111
1085
1112
# okay so now to generate the stuff to assign it back into the system
1086
1113
mod_pairs = mod_exprs .=> mod_syms
1087
- mod_names = (mod_syms... , )
1114
+ mod_names = (mod_syms... ,)
1088
1115
mod_og_val_fun = build_explicit_observed_function (
1089
1116
sys, Symbolics. scalarize .(first .(mod_pairs));
1090
1117
array_type = :tuple )
1091
1118
1092
1119
upd_funs = NamedTuple {mod_names} ((setu .((sys,), first .(mod_pairs))... ,))
1093
-
1120
+
1094
1121
if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
1095
1122
save_idxs = get (ic. callback_to_clocks, cb, Int[])
1096
1123
else
@@ -1104,10 +1131,12 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
1104
1131
upd_component_array = NamedTuple {mod_names} (modvals)
1105
1132
1106
1133
# update the observed values
1107
- obs_component_array = NamedTuple {obs_sym_tuple} (obs_fun (integ. u, integ. p, integ. t))
1134
+ obs_component_array = NamedTuple {obs_sym_tuple} (obs_fun (
1135
+ integ. u, integ. p, integ. t))
1108
1136
1109
1137
# let the user do their thing
1110
- modvals = if applicable (user_affect, upd_component_array, obs_component_array, ctx, integ)
1138
+ modvals = if applicable (
1139
+ user_affect, upd_component_array, obs_component_array, ctx, integ)
1111
1140
user_affect (upd_component_array, obs_component_array, ctx, integ)
1112
1141
elseif applicable (user_affect, upd_component_array, obs_component_array, ctx)
1113
1142
user_affect (upd_component_array, obs_component_array, ctx)
@@ -1122,15 +1151,16 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
1122
1151
1123
1152
# write the new values back to the integrator
1124
1153
_generated_writeback (integ, upd_funs, modvals)
1125
-
1154
+
1126
1155
for idx in save_idxs
1127
1156
SciMLBase. save_discretes! (integ, idx)
1128
1157
end
1129
1158
end
1130
1159
end
1131
1160
end
1132
1161
1133
- function compile_affect (affect:: Union{FunctionalAffect, ImperativeAffect} , cb, sys, dvs, ps; kwargs... )
1162
+ function compile_affect (
1163
+ affect:: Union{FunctionalAffect, ImperativeAffect} , cb, sys, dvs, ps; kwargs... )
1134
1164
compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
1135
1165
end
1136
1166
0 commit comments