@@ -387,6 +387,27 @@ function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrato
387387 end
388388end
389389
390+ function callback_save_header (sys:: AbstractSystem , cb)
391+ if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
392+ return (identity, identity)
393+ end
394+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
395+ isempty (save_idxs) && return (identity, identity)
396+
397+ wrapper = function (expr)
398+ return Func (expr. args, [],
399+ LiteralExpr (quote
400+ $ (expr. body)
401+ save_idxs = $ (save_idxs)
402+ for idx in save_idxs
403+ $ (SciMLBase. save_discretes!)($ (expr. args[1 ]), idx)
404+ end
405+ end ))
406+ end
407+
408+ return wrapper, wrapper
409+ end
410+
390411"""
391412 compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
392413
@@ -421,7 +442,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
421442end
422443
423444function compile_affect (cb:: SymbolicContinuousCallback , args... ; kwargs... )
424- compile_affect (affects (cb), args... ; kwargs... )
445+ compile_affect (affects (cb), cb, args... ; kwargs... )
425446end
426447
427448"""
@@ -441,7 +462,7 @@ Notes
441462 well-formed.
442463 - `kwargs` are passed through to `Symbolics.build_function`.
443464"""
444- function compile_affect (eqs:: Vector{Equation} , sys, dvs, ps; outputidxs = nothing ,
465+ function compile_affect (eqs:: Vector{Equation} , cb, sys, dvs, ps; outputidxs = nothing ,
445466 expression = Val{true }, checkvars = true , eval_expression = false ,
446467 eval_module = @__MODULE__ ,
447468 postprocess_affect_expr! = nothing , kwargs... )
@@ -497,7 +518,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497518 integ = gensym (:MTKIntegrator )
498519 pre = get_preprocess_constants (rhss)
499520 rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = Val{true },
500- wrap_code = add_integrator_header (sys, integ, outvar) .∘
521+ wrap_code = callback_save_header (sys, cb) .∘
522+ add_integrator_header (sys, integ, outvar) .∘
501523 wrap_array_vars (sys, rhss; dvs, ps = _ps) .∘
502524 wrap_parameter_dependencies (sys, false ),
503525 outputidxs = update_inds,
@@ -606,14 +628,14 @@ Compile a single continuous callback affect function(s).
606628function compile_affect_fn (cb, sys:: AbstractODESystem , dvs, ps, kwargs)
607629 eq_aff = affects (cb)
608630 eq_neg_aff = affect_negs (cb)
609- affect = compile_affect (eq_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
631+ affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
610632 if eq_neg_aff === eq_aff
611633 affect_neg = affect
612634 elseif isnothing (eq_neg_aff)
613635 affect_neg = nothing
614636 else
615637 affect_neg = compile_affect (
616- eq_neg_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
638+ eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
617639 end
618640 (affect = affect, affect_neg = affect_neg)
619641end
@@ -657,7 +679,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
657679 end
658680end
659681
660- function compile_user_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
682+ function compile_user_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
661683 dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
662684 v_inds = map (sym -> dvs_ind[sym], unknowns (affect))
663685
@@ -679,21 +701,31 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
679701 p = filter (x -> ! isnothing (x[2 ]), collect (zip (parameters_syms (affect), p_inds))) |>
680702 NamedTuple
681703
682- let u = u, p = p, user_affect = func (affect), ctx = context (affect)
704+ if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
705+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
706+ else
707+ save_idxs = Int[]
708+ end
709+ let u = u, p = p, user_affect = func (affect), ctx = context (affect),
710+ save_idxs = save_idxs
711+
683712 function (integ)
684713 user_affect (integ, u, p, ctx)
714+ for idx in save_idxs
715+ SciMLBase. save_discretes! (integ, idx)
716+ end
685717 end
686718 end
687719end
688720
689- function compile_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
690- compile_user_affect (affect, sys, dvs, ps; kwargs... )
721+ function compile_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
722+ compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
691723end
692724
693725function generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! = nothing ,
694726 kwargs... )
695727 cond = condition (cb)
696- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
728+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
697729 postprocess_affect_expr!, kwargs... )
698730 if cond isa AbstractVector
699731 # Preset Time
@@ -711,7 +743,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711743 kwargs... )
712744 else
713745 c = compile_condition (cb, sys, dvs, ps; expression = Val{false }, kwargs... )
714- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
746+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
715747 postprocess_affect_expr!, kwargs... )
716748 return DiscreteCallback (c, as)
717749 end
0 commit comments