@@ -387,6 +387,27 @@ function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrato
387
387
end
388
388
end
389
389
390
+ function callback_save_header (sys:: AbstractSystem , cb)
391
+ if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
392
+ return (identity, identity)
393
+ end
394
+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
395
+ isempty (save_idxs) && return (identity, identity)
396
+
397
+ wrapper = function (expr)
398
+ return Func (expr. args, [],
399
+ LiteralExpr (quote
400
+ $ (expr. body)
401
+ save_idxs = $ (save_idxs)
402
+ for idx in save_idxs
403
+ $ (SciMLBase. save_discretes!)($ (expr. args[1 ]), idx)
404
+ end
405
+ end ))
406
+ end
407
+
408
+ return wrapper, wrapper
409
+ end
410
+
390
411
"""
391
412
compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
392
413
@@ -421,7 +442,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
421
442
end
422
443
423
444
function compile_affect (cb:: SymbolicContinuousCallback , args... ; kwargs... )
424
- compile_affect (affects (cb), args... ; kwargs... )
445
+ compile_affect (affects (cb), cb, args... ; kwargs... )
425
446
end
426
447
427
448
"""
@@ -441,7 +462,7 @@ Notes
441
462
well-formed.
442
463
- `kwargs` are passed through to `Symbolics.build_function`.
443
464
"""
444
- function compile_affect (eqs:: Vector{Equation} , sys, dvs, ps; outputidxs = nothing ,
465
+ function compile_affect (eqs:: Vector{Equation} , cb, sys, dvs, ps; outputidxs = nothing ,
445
466
expression = Val{true }, checkvars = true , eval_expression = false ,
446
467
eval_module = @__MODULE__ ,
447
468
postprocess_affect_expr! = nothing , kwargs... )
@@ -497,7 +518,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497
518
integ = gensym (:MTKIntegrator )
498
519
pre = get_preprocess_constants (rhss)
499
520
rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = Val{true },
500
- wrap_code = add_integrator_header (sys, integ, outvar) .∘
521
+ wrap_code = callback_save_header (sys, cb) .∘
522
+ add_integrator_header (sys, integ, outvar) .∘
501
523
wrap_array_vars (sys, rhss; dvs, ps = _ps) .∘
502
524
wrap_parameter_dependencies (sys, false ),
503
525
outputidxs = update_inds,
@@ -606,14 +628,14 @@ Compile a single continuous callback affect function(s).
606
628
function compile_affect_fn (cb, sys:: AbstractODESystem , dvs, ps, kwargs)
607
629
eq_aff = affects (cb)
608
630
eq_neg_aff = affect_negs (cb)
609
- affect = compile_affect (eq_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
631
+ affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
610
632
if eq_neg_aff === eq_aff
611
633
affect_neg = affect
612
634
elseif isnothing (eq_neg_aff)
613
635
affect_neg = nothing
614
636
else
615
637
affect_neg = compile_affect (
616
- eq_neg_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
638
+ eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
617
639
end
618
640
(affect = affect, affect_neg = affect_neg)
619
641
end
@@ -657,7 +679,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
657
679
end
658
680
end
659
681
660
- function compile_user_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
682
+ function compile_user_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
661
683
dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
662
684
v_inds = map (sym -> dvs_ind[sym], unknowns (affect))
663
685
@@ -679,21 +701,31 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
679
701
p = filter (x -> ! isnothing (x[2 ]), collect (zip (parameters_syms (affect), p_inds))) |>
680
702
NamedTuple
681
703
682
- let u = u, p = p, user_affect = func (affect), ctx = context (affect)
704
+ if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
705
+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
706
+ else
707
+ save_idxs = Int[]
708
+ end
709
+ let u = u, p = p, user_affect = func (affect), ctx = context (affect),
710
+ save_idxs = save_idxs
711
+
683
712
function (integ)
684
713
user_affect (integ, u, p, ctx)
714
+ for idx in save_idxs
715
+ SciMLBase. save_discretes! (integ, idx)
716
+ end
685
717
end
686
718
end
687
719
end
688
720
689
- function compile_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
690
- compile_user_affect (affect, sys, dvs, ps; kwargs... )
721
+ function compile_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
722
+ compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
691
723
end
692
724
693
725
function generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! = nothing ,
694
726
kwargs... )
695
727
cond = condition (cb)
696
- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
728
+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
697
729
postprocess_affect_expr!, kwargs... )
698
730
if cond isa AbstractVector
699
731
# Preset Time
@@ -711,7 +743,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711
743
kwargs... )
712
744
else
713
745
c = compile_condition (cb, sys, dvs, ps; expression = Val{false }, kwargs... )
714
- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
746
+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
715
747
postprocess_affect_expr!, kwargs... )
716
748
return DiscreteCallback (c, as)
717
749
end
0 commit comments