@@ -104,28 +104,38 @@ The affect function updates the value at `x` in `modified` to be the result of e
104
104
modified:: Vector
105
105
mod_syms:: Vector{Symbol}
106
106
ctx:: Any
107
+ skip_checks:: Bool
107
108
end
108
109
109
110
function MutatingFunctionalAffect (f:: Function ;
110
111
observed:: NamedTuple = NamedTuple {()} (()),
111
112
modified:: NamedTuple = NamedTuple {()} (()),
112
- ctx = nothing )
113
- MutatingFunctionalAffect (f, collect (values (observed)), collect (keys (observed)),
114
- collect (values (modified)), collect (keys (modified)), ctx)
113
+ ctx = nothing ,
114
+ skip_checks = false )
115
+ MutatingFunctionalAffect (f,
116
+ collect (values (observed)), collect (keys (observed)),
117
+ collect (values (modified)), collect (keys (modified)),
118
+ ctx, skip_checks)
115
119
end
116
120
function MutatingFunctionalAffect (f:: Function , modified:: NamedTuple ;
117
- observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing )
118
- MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
121
+ observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks = false )
122
+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks )
119
123
end
120
124
function MutatingFunctionalAffect (
121
- f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing )
122
- MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
125
+ f:: Function , modified:: NamedTuple , observed:: NamedTuple ; ctx = nothing , skip_checks = false )
126
+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks )
123
127
end
124
128
function MutatingFunctionalAffect (
125
- f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx)
126
- MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx)
129
+ f:: Function , modified:: NamedTuple , observed:: NamedTuple , ctx; skip_checks = false )
130
+ MutatingFunctionalAffect (f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks )
127
131
end
128
132
133
+ function Base. show (io:: IO , mfa:: MutatingFunctionalAffect )
134
+ obs_vals = join (map ((ob,nm) -> " $ob => $nm " , mfa. obs, mfa. obs_syms), " , " )
135
+ mod_vals = join (map ((md,nm) -> " $md => $nm " , mfa. modified, mfa. mod_syms), " , " )
136
+ affect = mfa. f
137
+ print (io, " MutatingFunctionalAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
138
+ end
129
139
func (f:: MutatingFunctionalAffect ) = f. f
130
140
context (a:: MutatingFunctionalAffect ) = a. ctx
131
141
observed (a:: MutatingFunctionalAffect ) = a. obs
@@ -208,31 +218,101 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
208
218
"""
209
219
struct SymbolicContinuousCallback
210
220
eqs:: Vector{Equation}
221
+ initialize:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
222
+ finalize:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
211
223
affect:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
212
224
affect_neg:: Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
213
225
rootfind:: SciMLBase.RootfindOpt
214
- function SymbolicContinuousCallback (; eqs:: Vector{Equation} , affect = NULL_AFFECT,
215
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
216
- new (eqs, make_affect (affect), make_affect (affect_neg), rootfind)
226
+ function SymbolicContinuousCallback (;
227
+ eqs:: Vector{Equation} ,
228
+ affect = NULL_AFFECT,
229
+ affect_neg = affect,
230
+ rootfind = SciMLBase. LeftRootFind,
231
+ initialize= NULL_AFFECT,
232
+ finalize= NULL_AFFECT)
233
+ new (eqs, initialize, finalize, make_affect (affect), make_affect (affect_neg), rootfind)
217
234
end # Default affect to nothing
218
235
end
219
236
make_affect (affect) = affect
220
237
make_affect (affect:: Tuple ) = FunctionalAffect (affect... )
221
238
make_affect (affect:: NamedTuple ) = FunctionalAffect (; affect... )
222
239
223
240
function Base.:(== )(e1:: SymbolicContinuousCallback , e2:: SymbolicContinuousCallback )
224
- isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
241
+ isequal (e1. eqs, e2. eqs) && isequal (e1. affect, e2. affect) &&
242
+ isequal (e1. initialize, e2. initialize) && isequal (e1. finalize, e2. finalize) &&
225
243
isequal (e1. affect_neg, e2. affect_neg) && isequal (e1. rootfind, e2. rootfind)
226
244
end
227
245
Base. isempty (cb:: SymbolicContinuousCallback ) = isempty (cb. eqs)
228
246
function Base. hash (cb:: SymbolicContinuousCallback , s:: UInt )
247
+ hash_affect (affect:: AbstractVector , s) = foldr (hash, affect, init = s)
248
+ hash_affect (affect, s) = hash (cb. affect, s)
229
249
s = foldr (hash, cb. eqs, init = s)
230
- s = cb. affect isa AbstractVector ? foldr (hash, cb. affect, init = s) : hash (cb. affect, s)
231
- s = cb. affect_neg isa AbstractVector ? foldr (hash, cb. affect_neg, init = s) :
232
- hash (cb. affect_neg, s)
250
+ s = hash_affect (cb. affect, s)
251
+ s = hash_affect (cb. affect_neg, s)
252
+ s = hash_affect (cb. initialize, s)
253
+ s = hash_affect (cb. finalize, s)
233
254
hash (cb. rootfind, s)
234
255
end
235
256
257
+
258
+ function Base. show (io:: IO , cb:: SymbolicContinuousCallback )
259
+ indent = get (io, :indent , 0 )
260
+ iio = IOContext (io, :indent => indent+ 1 )
261
+ print (io, " SymbolicContinuousCallback(" )
262
+ print (iio, " Equations:" )
263
+ show (iio, equations (cb))
264
+ print (iio, " ; " )
265
+ if affects (cb) != NULL_AFFECT
266
+ print (iio, " Affect:" )
267
+ show (iio, affects (cb))
268
+ print (iio, " , " )
269
+ end
270
+ if affect_negs (cb) != NULL_AFFECT
271
+ print (iio, " Negative-edge affect:" )
272
+ show (iio, affect_negs (cb))
273
+ print (iio, " , " )
274
+ end
275
+ if initialize_affects (cb) != NULL_AFFECT
276
+ print (iio, " Initialization affect:" )
277
+ show (iio, initialize_affects (cb))
278
+ print (iio, " , " )
279
+ end
280
+ if finalize_affects (cb) != NULL_AFFECT
281
+ print (iio, " Finalization affect:" )
282
+ show (iio, finalize_affects (cb))
283
+ end
284
+ print (iio, " )" )
285
+ end
286
+
287
+ function Base. show (io:: IO , mime:: MIME"text/plain" , cb:: SymbolicContinuousCallback )
288
+ indent = get (io, :indent , 0 )
289
+ iio = IOContext (io, :indent => indent+ 1 )
290
+ println (io, " SymbolicContinuousCallback:" )
291
+ println (iio, " Equations:" )
292
+ show (iio, mime, equations (cb))
293
+ print (iio, " \n " )
294
+ if affects (cb) != NULL_AFFECT
295
+ println (iio, " Affect:" )
296
+ show (iio, mime, affects (cb))
297
+ print (iio, " \n " )
298
+ end
299
+ if affect_negs (cb) != NULL_AFFECT
300
+ println (iio, " Negative-edge affect:" )
301
+ show (iio, mime, affect_negs (cb))
302
+ print (iio, " \n " )
303
+ end
304
+ if initialize_affects (cb) != NULL_AFFECT
305
+ println (iio, " Initialization affect:" )
306
+ show (iio, mime, initialize_affects (cb))
307
+ print (iio, " \n " )
308
+ end
309
+ if finalize_affects (cb) != NULL_AFFECT
310
+ println (iio, " Finalization affect:" )
311
+ show (iio, mime, finalize_affects (cb))
312
+ print (iio, " \n " )
313
+ end
314
+ end
315
+
236
316
to_equation_vector (eq:: Equation ) = [eq]
237
317
to_equation_vector (eqs:: Vector{Equation} ) = eqs
238
318
function to_equation_vector (eqs:: Vector{Any} )
@@ -246,14 +326,14 @@ end # wrap eq in vector
246
326
SymbolicContinuousCallback (p:: Pair ) = SymbolicContinuousCallback (p[1 ], p[2 ])
247
327
SymbolicContinuousCallback (cb:: SymbolicContinuousCallback ) = cb # passthrough
248
328
function SymbolicContinuousCallback (eqs:: Equation , affect = NULL_AFFECT;
249
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
329
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT )
250
330
SymbolicContinuousCallback (
251
- eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind)
331
+ eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize = initialize, finalize = finalize )
252
332
end
253
333
function SymbolicContinuousCallback (eqs:: Vector{Equation} , affect = NULL_AFFECT;
254
- affect_neg = affect, rootfind = SciMLBase. LeftRootFind)
334
+ affect_neg = affect, rootfind = SciMLBase. LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT )
255
335
SymbolicContinuousCallback (
256
- eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind)
336
+ eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize = initialize, finalize = finalize )
257
337
end
258
338
259
339
SymbolicContinuousCallbacks (cb:: SymbolicContinuousCallback ) = [cb]
@@ -282,6 +362,16 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback})
282
362
mapreduce (affect_negs, vcat, cbs, init = Equation[])
283
363
end
284
364
365
+ initialize_affects (cb:: SymbolicContinuousCallback ) = cb. initialize
366
+ function initialize_affects (cbs:: Vector{SymbolicContinuousCallback} )
367
+ mapreduce (initialize_affects, vcat, cbs, init = Equation[])
368
+ end
369
+
370
+ finalize_affects (cb:: SymbolicContinuousCallback ) = cb. initialize
371
+ function finalize_affects (cbs:: Vector{SymbolicContinuousCallback} )
372
+ mapreduce (finalize_affects, vcat, cbs, init = Equation[])
373
+ end
374
+
285
375
namespace_affects (af:: Vector , s) = Equation[namespace_affect (a, s) for a in af]
286
376
namespace_affects (af:: FunctionalAffect , s) = namespace_affect (af, s)
287
377
namespace_affects (af:: MutatingFunctionalAffect , s) = namespace_affect (af, s)
@@ -292,6 +382,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo
292
382
eqs = namespace_equation .(equations (cb), (s,)),
293
383
affect = namespace_affects (affects (cb), s),
294
384
affect_neg = namespace_affects (affect_negs (cb), s),
385
+ initialize = namespace_affects (initialize_affects (cb), s),
386
+ finalize = namespace_affects (finalize_affects (cb), s),
295
387
rootfind = cb. rootfind)
296
388
end
297
389
@@ -681,8 +773,9 @@ function generate_single_rootfinding_callback(
681
773
initfn = SciMLBase. INITIALIZE_DEFAULT
682
774
end
683
775
return ContinuousCallback (
684
- cond, affect_function. affect, affect_function. affect_neg,
685
- rootfind = cb. rootfind, initialize = initfn)
776
+ cond, affect_function. affect, affect_function. affect_neg, rootfind = cb. rootfind,
777
+ initialize = isnothing (affect_function. initialize) ? SciMLBase. INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function. initialize (i),
778
+ finalize = isnothing (affect_function. finalize) ? SciMLBase. FINALIZE_DEFAULT : (c, u, t, i) -> affect_function. finalize (i))
686
779
end
687
780
688
781
function generate_vector_rootfinding_callback (
@@ -702,13 +795,12 @@ function generate_vector_rootfinding_callback(
702
795
_, rf_ip = generate_custom_function (
703
796
sys, rhss, dvs, ps; expression = Val{false }, kwargs... )
704
797
705
- affect_functions = @NamedTuple {affect:: Function , affect_neg:: Union{Function, Nothing} }[compile_affect_fn (
706
- cb,
707
- sys,
708
- dvs,
709
- ps,
710
- kwargs)
711
- for cb in cbs]
798
+ affect_functions = @NamedTuple {
799
+ affect:: Function ,
800
+ affect_neg:: Union{Function, Nothing} ,
801
+ initialize:: Union{Function, Nothing} ,
802
+ finalize:: Union{Function, Nothing} }[
803
+ compile_affect_fn (cb, sys, dvs, ps, kwargs) for cb in cbs]
712
804
cond = function (out, u, t, integ)
713
805
rf_ip (out, u, parameter_values (integ), t)
714
806
end
@@ -734,25 +826,27 @@ function generate_vector_rootfinding_callback(
734
826
affect_neg (integ)
735
827
end
736
828
end
737
- if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
738
- save_idxs = mapreduce (
739
- cb -> get (ic. callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[])
740
- initfn = if isempty (save_idxs)
741
- SciMLBase. INITIALIZE_DEFAULT
829
+ function handle_optional_setup_fn (funs, default)
830
+ if all (isnothing, funs)
831
+ return default
742
832
else
743
- let save_idxs = save_idxs
744
- function (cb, u, t, integrator)
745
- for idx in save_idxs
746
- SciMLBase. save_discretes! (integrator, idx)
833
+ return let funs = funs
834
+ function (cb, u, t, integ)
835
+ for func in funs
836
+ if isnothing (func)
837
+ continue
838
+ else
839
+ func (integ)
840
+ end
747
841
end
748
842
end
749
843
end
750
844
end
751
- else
752
- initfn = SciMLBase. INITIALIZE_DEFAULT
753
845
end
846
+ initialize = handle_optional_setup_fn (map (fn -> fn. initialize, affect_functions), SciMLBase. INITIALIZE_DEFAULT)
847
+ finalize = handle_optional_setup_fn (map (fn -> fn. finalize, affect_functions), SciMLBase. FINALIZE_DEFAULT)
754
848
return VectorContinuousCallback (
755
- cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initfn )
849
+ cond, affect, affect_neg, length (eqs), rootfind = rootfind, initialize = initialize, finalize = finalize )
756
850
end
757
851
758
852
"""
@@ -762,15 +856,23 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
762
856
eq_aff = affects (cb)
763
857
eq_neg_aff = affect_negs (cb)
764
858
affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
859
+ function compile_optional_affect (aff)
860
+ if isnothing (aff)
861
+ return nothing
862
+ else
863
+ affspr = compile_affect (aff, cb, sys, dvs, ps; expression = Val{true }, kwargs... )
864
+ @show affspr
865
+ return compile_affect (aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
866
+ end
867
+ end
765
868
if eq_neg_aff === eq_aff
766
869
affect_neg = affect
767
- elseif isnothing (eq_neg_aff)
768
- affect_neg = nothing
769
870
else
770
- affect_neg = compile_affect (
771
- eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
871
+ affect_neg = compile_optional_affect (eq_neg_aff)
772
872
end
773
- (affect = affect, affect_neg = affect_neg)
873
+ initialize = compile_optional_affect (initialize_affects (cb))
874
+ finalize = compile_optional_affect (finalize_affects (cb))
875
+ (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
774
876
end
775
877
776
878
function generate_rootfinding_callback (cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
@@ -877,7 +979,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
877
979
push! (syms_dedup, sym)
878
980
push! (exprs_dedup, exp)
879
981
push! (seen, sym)
880
- else
982
+ elseif ! affect . skip_checks
881
983
@warn " Expression $(expr) is aliased as $sym , which has already been used. The first definition will be used."
882
984
end
883
985
end
@@ -887,7 +989,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
887
989
obs_exprs = observed (affect)
888
990
for oexpr in obs_exprs
889
991
invalid_vars = invalid_variables (sys, oexpr)
890
- if length (invalid_vars) > 0
992
+ if length (invalid_vars) > 0 && ! affect . skip_checks
891
993
error (" Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing)." )
892
994
end
893
995
end
@@ -897,11 +999,11 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
897
999
898
1000
mod_exprs = modified (affect)
899
1001
for mexpr in mod_exprs
900
- if ! is_observed (sys, mexpr) && parameter_index (sys, mexpr) === nothing
901
- error (" Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect." )
1002
+ if ! is_observed (sys, mexpr) && parameter_index (sys, mexpr) === nothing && ! affect . skip_checks
1003
+ @warn (" Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect." )
902
1004
end
903
1005
invalid_vars = unassignable_variables (sys, mexpr)
904
- if length (invalid_vars) > 0
1006
+ if length (invalid_vars) > 0 && ! affect . skip_checks
905
1007
error (" Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing) or they may have been reduced away." )
906
1008
end
907
1009
end
@@ -911,7 +1013,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa
911
1013
sys, mod_exprs; return_inplace = true )
912
1014
913
1015
overlapping_syms = intersect (mod_syms, obs_syms)
914
- if length (overlapping_syms) > 0
1016
+ if length (overlapping_syms) > 0 && ! affect . skip_checks
915
1017
@warn " The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
916
1018
end
917
1019
0 commit comments